diff --git a/diff_diff/__init__.py b/diff_diff/__init__.py index f20be546..6aed7430 100644 --- a/diff_diff/__init__.py +++ b/diff_diff/__init__.py @@ -78,6 +78,7 @@ compute_pretrends_power, ) from diff_diff.prep import ( + aggregate_survey, aggregate_to_cohorts, balance_panel, create_event_time, @@ -328,6 +329,7 @@ "generate_survey_did_data", "generate_continuous_did_data", "create_event_time", + "aggregate_survey", "aggregate_to_cohorts", "rank_control_units", # Honest DiD sensitivity analysis diff --git a/diff_diff/prep.py b/diff_diff/prep.py index 97543576..c5bdbdb2 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -9,25 +9,30 @@ re-exported here for backward compatibility. """ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd -from diff_diff.utils import compute_synthetic_weights - # Re-export data generation functions from prep_dgp for backward compatibility -from diff_diff.prep_dgp import ( +from diff_diff.prep_dgp import ( # noqa: F401 generate_continuous_did_data, + generate_ddd_data, generate_did_data, - generate_staggered_data, + generate_event_study_data, generate_factor_data, - generate_ddd_data, generate_panel_data, - generate_event_study_data, + generate_staggered_data, generate_staggered_ddd_data, generate_survey_did_data, ) +from diff_diff.survey import ( + ResolvedSurveyDesign, + SurveyDesign, + compute_replicate_if_variance, + compute_survey_if_variance, +) +from diff_diff.utils import compute_synthetic_weights # Constants for rank_control_units _SIMILARITY_THRESHOLD_SD = 0.5 # Controls within this many SDs are "similar" @@ -1300,3 +1305,434 @@ def trim_weights( result[weight_col] = w return result + + +# --------------------------------------------------------------------------- +# Survey aggregation helpers +# --------------------------------------------------------------------------- + + +def _cell_mean_variance( + y_full: np.ndarray, + full_resolved: ResolvedSurveyDesign, + cell_mask: np.ndarray, + min_n: int, +) -> Tuple[float, float, int, bool]: + """Compute design-based mean and variance of the weighted mean for one cell. + + Uses full-design domain estimation: the influence function is zero-padded + outside the cell, preserving the full strata/PSU structure for variance + estimation. This is the methodologically correct approach for domain + estimation under complex survey designs (Lumley 2004, Section 3.4). + + Parameters + ---------- + y_full : np.ndarray + Outcome values for the full dataset (may contain NaN). + full_resolved : ResolvedSurveyDesign + Full-sample resolved survey design. + cell_mask : np.ndarray + Boolean mask identifying cell members in the full dataset. + min_n : int + Minimum valid observations for design-based variance. Below this + threshold, SRS fallback is used. + + Returns + ------- + mean : float + Design-weighted cell mean. + variance : float + Design-based variance of the cell mean (>= 0). Uses SRS fallback + when the design-based estimate is unidentifiable or n_valid < min_n. + n_valid : int + Number of non-missing observations in the cell. + used_srs_fallback : bool + True if SRS variance was used instead of design-based. + """ + y_cell = y_full[cell_mask] + w_cell = full_resolved.weights[cell_mask] + # Valid = non-missing AND positive weight (zero-weight rows are padding) + valid = ~np.isnan(y_cell) & (w_cell > 0) + n_valid = int(np.sum(valid)) + + if n_valid == 0: + return np.nan, np.nan, 0, False + + if n_valid < 2: + y_bar = float(y_cell[valid][0]) + return y_bar, np.nan, 1, False + + # Weighted mean from cell members (NaN-safe) + w_valid = w_cell * valid.astype(np.float64) + y_clean = np.where(valid, y_cell, 0.0) + sum_w = float(np.sum(w_valid)) + + if sum_w <= 0: + return np.nan, np.nan, n_valid, False + + y_bar = float(np.sum(w_valid * y_clean) / sum_w) + + # SRS fallback if below min_n threshold + # Normalize positive weights to mean=1 so fallback is scale-invariant + # (replicate designs preserve raw weight scale per survey.py:L189-240) + used_srs = False + if n_valid < min_n: + w_norm = w_valid.copy() + w_pos = w_norm[w_norm > 0] + if len(w_pos) > 0: + w_norm[w_norm > 0] = w_pos / w_pos.mean() + sum_wn = float(np.sum(w_norm)) + resid_sq = w_norm * (y_clean - y_bar) ** 2 + variance = float(np.sum(resid_sq) / (sum_wn**2) * n_valid / (n_valid - 1)) + return y_bar, max(variance, 0.0), n_valid, True + + # Full-design domain estimation: construct full-length psi with zeros + # outside the cell, preserving full strata/PSU structure for variance + n_total = len(y_full) + psi = np.zeros(n_total) + # Positions in full array where cell member has valid data + cell_indices = np.where(cell_mask)[0] + valid_positions = cell_indices[valid] + psi[valid_positions] = w_valid[valid] * (y_clean[valid] - y_bar) / sum_w + + # Route to TSL or replicate variance using the full design + if full_resolved.uses_replicate_variance: + variance, _ = compute_replicate_if_variance(psi, full_resolved) + else: + variance = compute_survey_if_variance(psi, full_resolved) + + # SRS fallback when design-based variance is unidentifiable + if np.isnan(variance): + w_norm = w_valid.copy() + w_pos = w_norm[w_norm > 0] + if len(w_pos) > 0: + w_norm[w_norm > 0] = w_pos / w_pos.mean() + sum_wn = float(np.sum(w_norm)) + resid_sq = w_norm * (y_clean - y_bar) ** 2 + variance = float(np.sum(resid_sq) / (sum_wn**2) * n_valid / (n_valid - 1)) + used_srs = True + + return y_bar, max(float(variance), 0.0), n_valid, used_srs + + +def aggregate_survey( + data: pd.DataFrame, + by: Union[str, List[str]], + outcomes: Union[str, List[str]], + survey_design: SurveyDesign, + covariates: Optional[Union[str, List[str]]] = None, + min_n: int = 2, + lonely_psu: Optional[str] = None, +) -> Tuple[pd.DataFrame, SurveyDesign]: + """Aggregate survey microdata to geographic-period cells with design-based precision. + + Computes design-weighted cell means and their Taylor-linearized (or + replicate-based) standard errors for each cell defined by the ``by`` + columns. Returns a panel-ready DataFrame with precision weights and a + pre-configured :class:`SurveyDesign` for second-stage DiD estimation. + + Each cell is treated as a subpopulation/domain of the full survey + design: influence function values are zero-padded outside the cell, + preserving full strata/PSU structure for variance estimation per + Lumley (2004) Section 3.4. + + Parameters + ---------- + data : pd.DataFrame + Individual-level microdata. + by : str or list of str + Columns defining cells (e.g., ``["state", "year"]``). The first + element is used as the clustering variable in the returned + SurveyDesign (geographic unit for second-stage inference). + outcomes : str or list of str + Outcome variable(s) to aggregate with full precision tracking. + Each outcome produces ``{name}_mean``, ``{name}_se``, + ``{name}_n``, and ``{name}_precision`` columns. When multiple + outcomes are given, panel filtering (non-estimable cell + removal, zero-weight PSU pruning) is based on the **first** + outcome only, consistent with the returned SurveyDesign. For + independent per-outcome support, call once per outcome. + survey_design : SurveyDesign + Survey design specification for the microdata. + covariates : str or list of str, optional + Additional variables to aggregate as design-weighted means only + (no SE/precision columns). + min_n : int, default 2 + Minimum respondents per cell. Cells below this threshold use + simple random sampling variance as a fallback. + lonely_psu : str, optional + Override the survey design's ``lonely_psu`` setting for within-cell + computation. One of ``"remove"``, ``"certainty"``, ``"adjust"``. + + Returns + ------- + panel_df : pd.DataFrame + Aggregated panel with columns: grouping variables, + ``{outcome}_mean``, ``{outcome}_se``, ``{outcome}_n``, + ``{outcome}_precision``, ``{outcome}_weight``, + ``{covariate}_mean``, ``cell_n``, ``cell_n_eff``, + ``srs_fallback``. The ``_weight`` column is a fit-ready + version of ``_precision`` with NaN/Inf mapped to 0.0. + second_stage_design : SurveyDesign + Pre-configured for second-stage estimation with + ``weight_type="aweight"``, precision weights from the first + outcome, and geographic clustering via ``psu``. + + Examples + -------- + >>> design = SurveyDesign(weights="finalwt", strata="strat", psu="psu") + >>> panel, stage2 = aggregate_survey( + ... microdata, by=["state", "year"], + ... outcomes="smoking_rate", survey_design=design, + ... ) + >>> # Add treatment/time indicators at the panel level, then fit: + >>> # panel["treated"] = ... # e.g., from policy adoption data + >>> # panel["post"] = (panel["year"] >= treatment_year).astype(int) + >>> # result = DifferenceInDifferences().fit( + >>> # panel, outcome="smoking_rate_mean", + >>> # treatment="treated", time="post", survey_design=stage2, + >>> # ) + """ + import warnings + from dataclasses import replace + + # --- Normalize inputs --- + by_cols = [by] if isinstance(by, str) else list(by) + outcome_cols = [outcomes] if isinstance(outcomes, str) else list(outcomes) + cov_cols = ( + [covariates] if isinstance(covariates, str) else list(covariates) if covariates else [] + ) + + # --- Validate --- + if not by_cols: + raise ValueError("'by' must specify at least one grouping column") + if not outcome_cols: + raise ValueError("'outcomes' must specify at least one outcome variable") + + all_cols = by_cols + outcome_cols + cov_cols + missing = [c for c in all_cols if c not in data.columns] + if missing: + raise ValueError(f"Columns not found in DataFrame: {missing}") + + overlap = set(by_cols) & (set(outcome_cols) | set(cov_cols)) + if overlap: + raise ValueError(f"Columns appear in both 'by' and outcomes/covariates: {overlap}") + + if not isinstance(survey_design, SurveyDesign): + raise TypeError( + f"survey_design must be a SurveyDesign instance, got {type(survey_design).__name__}" + ) + + if min_n < 1: + raise ValueError(f"min_n must be >= 1, got {min_n}") + + if lonely_psu is not None and lonely_psu not in ("remove", "certainty", "adjust"): + raise ValueError( + f"lonely_psu must be 'remove', 'certainty', or 'adjust', got '{lonely_psu}'" + ) + + # --- Empty-input guard --- + if data.empty: + raise ValueError("data must be non-empty") + + # --- Validate grouping columns have no missing values --- + by_missing = data[by_cols].isna().any() + cols_with_na = list(by_missing[by_missing].index) + if cols_with_na: + raise ValueError( + f"Missing values in grouping column(s): {cols_with_na}. " + f"Drop or fill NaN values before calling aggregate_survey()." + ) + + # --- Resolve design once on full data --- + effective_design = ( + replace(survey_design, lonely_psu=lonely_psu) if lonely_psu else survey_design + ) + full_resolved = effective_design.resolve(data) + + # --- Precompute full-length outcome/covariate arrays --- + n_total = len(data) + all_vars = outcome_cols + cov_cols + non_numeric = [v for v in all_vars if not pd.api.types.is_numeric_dtype(data[v])] + if non_numeric: + raise ValueError( + f"Non-numeric column(s) in outcomes/covariates: {non_numeric}. " + f"All outcome and covariate columns must be numeric." + ) + y_arrays: Dict[str, np.ndarray] = {var: data[var].values.astype(np.float64) for var in all_vars} + + # --- Per-cell computation --- + # Use groupby().indices for position-based cell membership (safe with + # duplicate DataFrame indices, no column injection into user data) + grouped = data.groupby(by_cols, sort=True) + cell_indices = grouped.indices # dict of cell_key → positional indices + rows: List[Dict[str, Any]] = [] + srs_cells: List[str] = [] + zero_var_cells: List[str] = [] + + for cell_key, pos_idx in cell_indices.items(): + # Boolean mask for full-design domain estimation + cell_mask = np.zeros(n_total, dtype=bool) + cell_mask[pos_idx] = True + + cell_n = int(np.sum(cell_mask)) + cell_key_str = str(cell_key) + + # Cell-level statistics (Kish ESS is a property of the cell) + cell_w = full_resolved.weights[cell_mask] + sum_w = float(np.sum(cell_w)) + sum_w2 = float(np.sum(cell_w**2)) + cell_n_eff = (sum_w**2 / sum_w2) if sum_w2 > 0 else 0.0 + + # Build row dict with grouping columns + row: Dict[str, Any] = {} + if len(by_cols) == 1: + row[by_cols[0]] = cell_key + else: + for i, col in enumerate(by_cols): + row[col] = cell_key[i] + + row["cell_n"] = cell_n + row["cell_n_eff"] = cell_n_eff + + cell_srs_fallback = False + + # Outcomes: mean + SE + n + precision (full-design domain estimation) + for var in outcome_cols: + y_bar, variance, n_valid, used_srs = _cell_mean_variance( + y_arrays[var], + full_resolved, + cell_mask, + min_n, + ) + se = float(np.sqrt(variance)) if not np.isnan(variance) else np.nan + + if used_srs: + cell_srs_fallback = True + + # Zero variance → precision NaN + if se == 0.0: + precision = np.nan + zero_var_cells.append(cell_key_str) + elif np.isnan(se): + precision = np.nan + else: + precision = 1.0 / variance + + row[f"{var}_mean"] = y_bar + row[f"{var}_se"] = se + row[f"{var}_n"] = n_valid + row[f"{var}_precision"] = precision + + # Covariates: design-weighted mean only + for var in cov_cols: + y_cell = y_arrays[var][cell_mask] + valid = ~np.isnan(y_cell) + w_valid = cell_w * valid.astype(np.float64) + sw = float(np.sum(w_valid)) + if sw > 0: + row[f"{var}_mean"] = float(np.sum(w_valid * np.where(valid, y_cell, 0.0)) / sw) + else: + row[f"{var}_mean"] = np.nan + + row["srs_fallback"] = cell_srs_fallback + if cell_srs_fallback: + srs_cells.append(cell_key_str) + + rows.append(row) + + # --- Warnings --- + if srs_cells: + warnings.warn( + f"Design-based variance not estimable for {len(srs_cells)} cell(s); " + f"using SRS fallback: {srs_cells[:5]}" + + (f" ... and {len(srs_cells) - 5} more" if len(srs_cells) > 5 else ""), + UserWarning, + stacklevel=2, + ) + if zero_var_cells: + warnings.warn( + f"Zero variance in {len(zero_var_cells)} cell(s) (precision set to NaN): " + f"{zero_var_cells[:5]}" + + (f" ... and {len(zero_var_cells) - 5} more" if len(zero_var_cells) > 5 else ""), + UserWarning, + stacklevel=2, + ) + + # --- Assemble output --- + panel_df = pd.DataFrame(rows) + + # Sort by grouping columns + panel_df = panel_df.sort_values(by_cols).reset_index(drop=True) + + # --- Drop non-estimable cells --- + # Cells with non-finite mean (n_valid==0 or all-missing) cannot contribute + # to second-stage estimation and would cause fit() to reject NaN outcomes. + # Dropping them also removes all-zero-weight PSUs from the panel. + first_outcome = outcome_cols[0] + mean_col = f"{first_outcome}_mean" + nonestimable = ~np.isfinite(panel_df[mean_col].values) + if np.any(nonestimable): + n_dropped = int(np.sum(nonestimable)) + dropped_keys = panel_df.loc[nonestimable, by_cols].values.tolist() + # Warn about secondary outcomes losing valid data in dropped cells + secondary_loss = [] + for var in outcome_cols[1:]: + valid_secondary = np.isfinite(panel_df.loc[nonestimable, f"{var}_mean"].values) + if np.any(valid_secondary): + secondary_loss.append(var) + msg = ( + f"Dropped {n_dropped} non-estimable cell(s) (based on first outcome " + f"'{first_outcome}'): {dropped_keys[:5]}" + + (f" ... and {n_dropped - 5} more" if n_dropped > 5 else "") + ) + if secondary_loss: + msg += ( + f". Note: {secondary_loss} had valid data in dropped cells. " + f"For independent per-outcome support, call once per outcome." + ) + warnings.warn(msg, UserWarning, stacklevel=2) + panel_df = panel_df[~nonestimable].reset_index(drop=True) + + # --- Construct second-stage SurveyDesign --- + # Create a fit-ready weight column: NaN/Inf precision → 0.0 so downstream + # resolve() doesn't reject missing weights. Diagnostic *_precision is kept. + weight_col = f"{first_outcome}_weight" + panel_df[weight_col] = np.where( + np.isfinite(panel_df[f"{first_outcome}_precision"]), + panel_df[f"{first_outcome}_precision"], + 0.0, + ) + + # Drop geographic units (PSUs) with zero total weight — they would + # inflate survey df and distort second-stage variance estimation. + geo_col = by_cols[0] + geo_weight = panel_df.groupby(geo_col)[weight_col].sum() + zero_geos = geo_weight[geo_weight == 0].index + if len(zero_geos) > 0: + n_before = len(panel_df) + panel_df = panel_df[~panel_df[geo_col].isin(zero_geos)].reset_index(drop=True) + n_after = len(panel_df) + warnings.warn( + f"Dropped {n_before - n_after} cell(s) from {len(zero_geos)} " + f"geographic unit(s) with zero total weight: " + f"{list(zero_geos[:5])}" + + (f" ... and {len(zero_geos) - 5} more" if len(zero_geos) > 5 else ""), + UserWarning, + stacklevel=2, + ) + + # Guard: all cells dropped + if panel_df.empty: + raise ValueError( + "No estimable cells remain after aggregation. " + "All cells had missing outcomes or zero effective weight." + ) + + second_stage_design = SurveyDesign( + weights=weight_col, + weight_type="aweight", + psu=geo_col, + ) + + return panel_df, second_stage_design diff --git a/docs/api/prep.rst b/docs/api/prep.rst index 462e82d0..7491c085 100644 --- a/docs/api/prep.rst +++ b/docs/api/prep.rst @@ -250,6 +250,52 @@ Example outcome='outcome' ) +Survey Aggregation +------------------ + +aggregate_survey +~~~~~~~~~~~~~~~~ + +Aggregate survey microdata to geographic-period cells with design-based precision. + +.. autofunction:: diff_diff.aggregate_survey + +Example +^^^^^^^ + +.. code-block:: python + + from diff_diff import aggregate_survey, SurveyDesign, DifferenceInDifferences + + # Define the survey design for the microdata + design = SurveyDesign(weights="finalwt", strata="strat", psu="psu") + + # Aggregate to state-year panel with design-based SEs + panel, stage2 = aggregate_survey( + microdata, + by=["state", "year"], + outcomes="smoking_rate", + covariates=["age", "income"], + survey_design=design, + ) + + # panel has: state, year, smoking_rate_mean, smoking_rate_se, + # smoking_rate_n, smoking_rate_precision, smoking_rate_weight, + # age_mean, income_mean, cell_n, cell_n_eff, srs_fallback + # + # *_weight is fit-ready (NaN precision -> 0.0) + # Non-estimable cells and zero-weight geos are dropped automatically. + # Multi-outcome filtering is keyed off the first outcome. + + # stage2 is pre-configured: aweights + state-level clustering + # Add treatment/time indicators at the panel level, then fit: + # panel["treated"] = ... # from policy adoption data + # panel["post"] = (panel["year"] >= treatment_year).astype(int) + # result = DifferenceInDifferences().fit( + # panel, outcome="smoking_rate_mean", + # treatment="treated", time="post", survey_design=stage2, + # ) + Data Validation --------------- diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 42f0fa5f..576b5e3d 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -2303,6 +2303,31 @@ unequal selection probabilities). Linearization variance estimation, matching the R `survey` package convention that clusters are the primary sampling units. +### Survey Aggregation (`aggregate_survey`) + +Aggregation of individual-level survey microdata to geographic-period cells with +design-based precision estimates, for use as a pre-processing step before panel +DiD estimation on repeated cross-section survey data. + +- **Reference**: Lumley (2004) "Analysis of Complex Survey Samples", Journal of + Statistical Software 9(8), Section 3.4 (domain estimation). +- **Cell mean**: Design-weighted mean `ȳ_g = Σ w_i y_i / Σ w_i` for each cell g + defined by grouping columns (e.g., state × year). +- **Cell variance**: Each cell is treated as a subpopulation/domain of the full + survey design (consistent with `SurveyDesign.subpopulation()` and the + Subpopulation Analysis section below). The influence function + `ψ_i = w_i (y_i - ȳ_g) / Σ w_j` is zero-padded outside the cell, preserving + full strata/PSU structure for variance estimation via `compute_survey_if_variance()` + (TSL) or `compute_replicate_if_variance()` (replicate designs). +- **Precision weight**: `1 / V(ȳ_g)` used as inverse-variance weight (aweight) + in second-stage DiD estimation. +- **Note:** SRS fallback when design-based variance is unidentifiable (e.g., all + strata contribute zero variance) or when the cell has fewer than `min_n` valid + observations. Formula: `V_SRS = Σ w_i(y_i - ȳ)² / (Σ w_j)² × n/(n-1)`. + Cells using SRS fallback are flagged via `srs_fallback` column. +- **Edge case**: Zero-variance cells (all observations identical) set precision to + NaN to avoid infinite weights in second-stage WLS. + ### Survey-Aware Bootstrap (Phase 6) Two strategies for bootstrap variance under complex survey designs: diff --git a/tests/test_doc_snippets.py b/tests/test_doc_snippets.py index cb4296da..8b02b565 100644 --- a/tests/test_doc_snippets.py +++ b/tests/test_doc_snippets.py @@ -97,6 +97,7 @@ def _extract_snippets(rst_path: Path) -> List[Tuple[int, str]]: r"pip\s+install", r"wild_bootstrap_se\(X,", # low-level array API pseudo-code r"wide_to_long\(", # references undefined wide_data variable + r"aggregate_survey\(", # references undefined microdata variable ] # Third-party packages imported by comparison-page snippets that may not diff --git a/tests/test_prep.py b/tests/test_prep.py index 0674186a..e08c48e2 100644 --- a/tests/test_prep.py +++ b/tests/test_prep.py @@ -7,6 +7,7 @@ import pytest from diff_diff.prep import ( + aggregate_survey, aggregate_to_cohorts, balance_panel, create_event_time, @@ -17,6 +18,7 @@ validate_did_data, wide_to_long, ) +from diff_diff.survey import SurveyDesign class TestMakeTreatmentIndicator: @@ -102,10 +104,9 @@ def test_treatment_start(self): def test_datetime_column(self): """Test with datetime column.""" - df = pd.DataFrame({ - "date": pd.to_datetime(["2020-01-01", "2020-06-01", "2021-01-01"]), - "y": [1, 2, 3] - }) + df = pd.DataFrame( + {"date": pd.to_datetime(["2020-01-01", "2020-06-01", "2021-01-01"]), "y": [1, 2, 3]} + ) result = make_post_indicator(df, "date", treatment_start="2020-06-01") assert result["post"].tolist() == [0, 1, 1] @@ -133,50 +134,36 @@ class TestWideToLong: def test_basic_conversion(self): """Test basic wide to long conversion.""" - wide_df = pd.DataFrame({ - "firm_id": [1, 2], - "sales_2019": [100, 150], - "sales_2020": [110, 160], - "sales_2021": [120, 170] - }) + wide_df = pd.DataFrame( + { + "firm_id": [1, 2], + "sales_2019": [100, 150], + "sales_2020": [110, 160], + "sales_2021": [120, 170], + } + ) result = wide_to_long( wide_df, value_columns=["sales_2019", "sales_2020", "sales_2021"], id_column="firm_id", time_name="year", - value_name="sales" + value_name="sales", ) assert len(result) == 6 assert set(result.columns) == {"firm_id", "year", "sales"} def test_with_time_values(self): """Test with explicit time values.""" - wide_df = pd.DataFrame({ - "id": [1], - "t1": [10], - "t2": [20] - }) + wide_df = pd.DataFrame({"id": [1], "t1": [10], "t2": [20]}) result = wide_to_long( - wide_df, - value_columns=["t1", "t2"], - id_column="id", - time_values=[2020, 2021] + wide_df, value_columns=["t1", "t2"], id_column="id", time_values=[2020, 2021] ) assert result["period"].tolist() == [2020, 2021] def test_preserves_other_columns(self): """Test that other columns are preserved.""" - wide_df = pd.DataFrame({ - "id": [1, 2], - "group": ["A", "B"], - "t1": [10, 20], - "t2": [15, 25] - }) - result = wide_to_long( - wide_df, - value_columns=["t1", "t2"], - id_column="id" - ) + wide_df = pd.DataFrame({"id": [1, 2], "group": ["A", "B"], "t1": [10, 20], "t2": [15, 25]}) + result = wide_to_long(wide_df, value_columns=["t1", "t2"], id_column="id") assert "group" in result.columns assert result[result["id"] == 1]["group"].tolist() == ["A", "A"] @@ -198,32 +185,26 @@ class TestBalancePanel: def test_inner_balance(self): """Test inner balance (keep complete units only).""" - df = pd.DataFrame({ - "unit": [1, 1, 1, 2, 2, 3, 3, 3], - "period": [1, 2, 3, 1, 2, 1, 2, 3], - "y": [10, 11, 12, 20, 21, 30, 31, 32] - }) + df = pd.DataFrame( + { + "unit": [1, 1, 1, 2, 2, 3, 3, 3], + "period": [1, 2, 3, 1, 2, 1, 2, 3], + "y": [10, 11, 12, 20, 21, 30, 31, 32], + } + ) result = balance_panel(df, "unit", "period", method="inner") assert set(result["unit"].unique()) == {1, 3} assert len(result) == 6 def test_outer_balance(self): """Test outer balance (include all combinations).""" - df = pd.DataFrame({ - "unit": [1, 1, 2], - "period": [1, 2, 1], - "y": [10, 11, 20] - }) + df = pd.DataFrame({"unit": [1, 1, 2], "period": [1, 2, 1], "y": [10, 11, 20]}) result = balance_panel(df, "unit", "period", method="outer") assert len(result) == 4 # 2 units x 2 periods def test_fill_with_value(self): """Test fill method with specific value.""" - df = pd.DataFrame({ - "unit": [1, 1, 2], - "period": [1, 2, 1], - "y": [10.0, 11.0, 20.0] - }) + df = pd.DataFrame({"unit": [1, 1, 2], "period": [1, 2, 1], "y": [10.0, 11.0, 20.0]}) result = balance_panel(df, "unit", "period", method="fill", fill_value=0.0) assert len(result) == 4 missing_row = result[(result["unit"] == 2) & (result["period"] == 2)] @@ -231,11 +212,13 @@ def test_fill_with_value(self): def test_fill_forward_backward(self): """Test fill method with forward/backward fill.""" - df = pd.DataFrame({ - "unit": [1, 1, 1, 2, 2], - "period": [1, 2, 3, 1, 3], # Unit 2 missing period 2 - "y": [10.0, 11.0, 12.0, 20.0, 22.0] - }) + df = pd.DataFrame( + { + "unit": [1, 1, 1, 2, 2], + "period": [1, 2, 3, 1, 3], # Unit 2 missing period 2 + "y": [10.0, 11.0, 12.0, 20.0, 22.0], + } + ) result = balance_panel(df, "unit", "period", method="fill", fill_value=None) assert len(result) == 6 # Check that unit 2, period 2 was filled @@ -255,11 +238,9 @@ class TestValidateDidData: def test_valid_data(self): """Test validation of valid data.""" - df = pd.DataFrame({ - "y": [1.0, 2.0, 3.0, 4.0], - "treated": [0, 0, 1, 1], - "post": [0, 1, 0, 1] - }) + df = pd.DataFrame( + {"y": [1.0, 2.0, 3.0, 4.0], "treated": [0, 0, 1, 1], "post": [0, 1, 0, 1]} + ) result = validate_did_data(df, "y", "treated", "post", raise_on_error=False) assert result["valid"] is True assert len(result["errors"]) == 0 @@ -273,33 +254,25 @@ def test_missing_column(self): def test_non_numeric_outcome(self): """Test validation catches non-numeric outcome.""" - df = pd.DataFrame({ - "y": ["a", "b", "c", "d"], - "treated": [0, 0, 1, 1], - "post": [0, 1, 0, 1] - }) + df = pd.DataFrame( + {"y": ["a", "b", "c", "d"], "treated": [0, 0, 1, 1], "post": [0, 1, 0, 1]} + ) result = validate_did_data(df, "y", "treated", "post", raise_on_error=False) assert result["valid"] is False assert any("numeric" in e for e in result["errors"]) def test_non_binary_treatment(self): """Test validation catches non-binary treatment.""" - df = pd.DataFrame({ - "y": [1.0, 2.0, 3.0], - "treated": [0, 1, 2], - "post": [0, 1, 0] - }) + df = pd.DataFrame({"y": [1.0, 2.0, 3.0], "treated": [0, 1, 2], "post": [0, 1, 0]}) result = validate_did_data(df, "y", "treated", "post", raise_on_error=False) assert result["valid"] is False assert any("binary" in e for e in result["errors"]) def test_missing_values(self): """Test validation catches missing values.""" - df = pd.DataFrame({ - "y": [1.0, np.nan, 3.0, 4.0], - "treated": [0, 0, 1, 1], - "post": [0, 1, 0, 1] - }) + df = pd.DataFrame( + {"y": [1.0, np.nan, 3.0, 4.0], "treated": [0, 0, 1, 1], "post": [0, 1, 0, 1]} + ) result = validate_did_data(df, "y", "treated", "post", raise_on_error=False) assert result["valid"] is False assert any("missing" in e for e in result["errors"]) @@ -312,12 +285,14 @@ def test_raises_on_error(self): def test_panel_validation(self): """Test panel-specific validation.""" - df = pd.DataFrame({ - "y": [1.0, 2.0, 3.0, 4.0], - "treated": [0, 0, 1, 1], - "post": [0, 1, 0, 1], - "unit": [1, 1, 2, 2] - }) + df = pd.DataFrame( + { + "y": [1.0, 2.0, 3.0, 4.0], + "treated": [0, 0, 1, 1], + "post": [0, 1, 0, 1], + "unit": [1, 1, 2, 2], + } + ) result = validate_did_data(df, "y", "treated", "post", unit="unit", raise_on_error=False) assert result["valid"] is True assert result["summary"]["n_units"] == 2 @@ -328,21 +303,25 @@ class TestSummarizeDidData: def test_basic_summary(self): """Test basic summary statistics.""" - df = pd.DataFrame({ - "y": [10, 11, 12, 13, 20, 21, 22, 23], - "treated": [0, 0, 1, 1, 0, 0, 1, 1], - "post": [0, 1, 0, 1, 0, 1, 0, 1] - }) + df = pd.DataFrame( + { + "y": [10, 11, 12, 13, 20, 21, 22, 23], + "treated": [0, 0, 1, 1, 0, 0, 1, 1], + "post": [0, 1, 0, 1, 0, 1, 0, 1], + } + ) summary = summarize_did_data(df, "y", "treated", "post") assert len(summary) == 5 # 4 groups + DiD estimate def test_did_estimate_included(self): """Test that DiD estimate is calculated.""" - df = pd.DataFrame({ - "y": [10, 20, 15, 30], # Perfect DiD = 30-15 - (20-10) = 5 - "treated": [0, 0, 1, 1], - "post": [0, 1, 0, 1] - }) + df = pd.DataFrame( + { + "y": [10, 20, 15, 30], # Perfect DiD = 30-15 - (20-10) = 5 + "treated": [0, 0, 1, 1], + "post": [0, 1, 0, 1], + } + ) summary = summarize_did_data(df, "y", "treated", "post") assert "DiD Estimate" in summary.index assert summary.loc["DiD Estimate", "mean"] == 5.0 @@ -369,11 +348,7 @@ def test_treatment_effect_recovery(self): true_effect = 5.0 data = generate_did_data( - n_units=200, - n_periods=4, - treatment_effect=true_effect, - noise_sd=0.5, - seed=42 + n_units=200, n_periods=4, treatment_effect=true_effect, noise_sd=0.5, seed=42 ) did = DifferenceInDifferences() @@ -405,21 +380,25 @@ class TestCreateEventTime: def test_basic_event_time(self): """Test basic event time calculation.""" - df = pd.DataFrame({ - "unit": [1, 1, 1, 2, 2, 2], - "year": [2018, 2019, 2020, 2018, 2019, 2020], - "treatment_year": [2019, 2019, 2019, 2020, 2020, 2020] - }) + df = pd.DataFrame( + { + "unit": [1, 1, 1, 2, 2, 2], + "year": [2018, 2019, 2020, 2018, 2019, 2020], + "treatment_year": [2019, 2019, 2019, 2020, 2020, 2020], + } + ) result = create_event_time(df, "year", "treatment_year") assert result["event_time"].tolist() == [-1, 0, 1, -2, -1, 0] def test_never_treated(self): """Test handling of never-treated units.""" - df = pd.DataFrame({ - "unit": [1, 1, 2, 2], - "year": [2019, 2020, 2019, 2020], - "treatment_year": [2020, 2020, np.nan, np.nan] - }) + df = pd.DataFrame( + { + "unit": [1, 1, 2, 2], + "year": [2019, 2020, 2019, 2020], + "treatment_year": [2020, 2020, np.nan, np.nan], + } + ) result = create_event_time(df, "year", "treatment_year") assert result.loc[0, "event_time"] == -1 assert result.loc[1, "event_time"] == 0 @@ -428,10 +407,7 @@ def test_never_treated(self): def test_custom_column_name(self): """Test custom output column name.""" - df = pd.DataFrame({ - "year": [2019, 2020], - "treat_time": [2020, 2020] - }) + df = pd.DataFrame({"year": [2019, 2020], "treat_time": [2020, 2020]}) result = create_event_time(df, "year", "treat_time", new_column="rel_time") assert "rel_time" in result.columns @@ -441,12 +417,14 @@ class TestAggregateToCohorts: def test_basic_aggregation(self): """Test basic cohort aggregation.""" - df = pd.DataFrame({ - "unit": [1, 1, 2, 2, 3, 3, 4, 4], - "period": [0, 1, 0, 1, 0, 1, 0, 1], - "treated": [1, 1, 1, 1, 0, 0, 0, 0], - "y": [10, 15, 12, 17, 8, 10, 9, 11] - }) + df = pd.DataFrame( + { + "unit": [1, 1, 2, 2, 3, 3, 4, 4], + "period": [0, 1, 0, 1, 0, 1, 0, 1], + "treated": [1, 1, 1, 1, 0, 0, 0, 0], + "y": [10, 15, 12, 17, 8, 10, 9, 11], + } + ) result = aggregate_to_cohorts(df, "unit", "period", "treated", "y") assert len(result) == 4 # 2 treatment groups x 2 periods assert "mean_y" in result.columns @@ -454,13 +432,15 @@ def test_basic_aggregation(self): def test_with_covariates(self): """Test aggregation with covariates.""" - df = pd.DataFrame({ - "unit": [1, 1, 2, 2], - "period": [0, 1, 0, 1], - "treated": [1, 1, 0, 0], - "y": [10, 15, 8, 10], - "x": [1.0, 1.5, 0.5, 0.8] - }) + df = pd.DataFrame( + { + "unit": [1, 1, 2, 2], + "period": [0, 1, 0, 1], + "treated": [1, 1, 0, 0], + "y": [10, 15, 8, 10], + "x": [1.0, 1.5, 0.5, 0.8], + } + ) result = aggregate_to_cohorts(df, "unit", "period", "treated", "y", covariates=["x"]) assert "x" in result.columns @@ -478,7 +458,7 @@ def test_basic_ranking(self): unit_column="unit", time_column="period", outcome_column="outcome", - treatment_column="treated" + treatment_column="treated", ) assert "quality_score" in result.columns assert "outcome_trend_score" in result.columns @@ -502,7 +482,7 @@ def test_with_covariates(self): time_column="period", outcome_column="outcome", treatment_column="treated", - covariates=["x1"] + covariates=["x1"], ) assert not result["covariate_score"].isna().all() @@ -517,7 +497,7 @@ def test_explicit_treated_units(self): unit_column="unit", time_column="period", outcome_column="outcome", - treated_units=[0, 1, 2] + treated_units=[0, 1, 2], ) # Should not include treated units in ranking assert 0 not in result["unit"].values @@ -536,7 +516,7 @@ def test_exclude_units(self): time_column="period", outcome_column="outcome", treatment_column="treated", - exclude_units=[15, 16, 17] + exclude_units=[15, 16, 17], ) assert 15 not in result["unit"].values assert 16 not in result["unit"].values @@ -559,7 +539,7 @@ def test_require_units(self): outcome_column="outcome", treatment_column="treated", require_units=require, - n_top=5 + n_top=5, ) # Required units should be present for u in require: @@ -579,7 +559,7 @@ def test_n_top_limit(self): time_column="period", outcome_column="outcome", treatment_column="treated", - n_top=10 + n_top=10, ) assert len(result) == 10 @@ -597,7 +577,7 @@ def test_suggest_treatment_candidates(self): time_column="period", outcome_column="outcome", suggest_treatment_candidates=True, - n_treatment_candidates=5 + n_treatment_candidates=5, ) assert "treatment_candidate_score" in result.columns assert len(result) == 5 @@ -614,7 +594,7 @@ def test_original_unchanged(self): unit_column="unit", time_column="period", outcome_column="outcome", - treatment_column="treated" + treatment_column="treated", ) assert data.columns.tolist() == original_cols @@ -626,10 +606,7 @@ def test_error_missing_column(self): with pytest.raises(ValueError, match="not found"): rank_control_units( - data, - unit_column="missing_col", - time_column="period", - outcome_column="outcome" + data, unit_column="missing_col", time_column="period", outcome_column="outcome" ) def test_error_both_treatment_specs(self): @@ -645,7 +622,7 @@ def test_error_both_treatment_specs(self): time_column="period", outcome_column="outcome", treatment_column="treated", - treated_units=[0, 1] + treated_units=[0, 1], ) def test_error_require_and_exclude_same_unit(self): @@ -662,7 +639,7 @@ def test_error_require_and_exclude_same_unit(self): outcome_column="outcome", treatment_column="treated", require_units=[5], - exclude_units=[5] + exclude_units=[5], ) def test_synthetic_weight_sum(self): @@ -676,7 +653,7 @@ def test_synthetic_weight_sum(self): unit_column="unit", time_column="period", outcome_column="outcome", - treatment_column="treated" + treatment_column="treated", ) # Synthetic weights should sum to approximately 1 @@ -694,7 +671,7 @@ def test_pre_periods_explicit(self): time_column="period", outcome_column="outcome", treatment_column="treated", - pre_periods=[0, 1] # Only use first two periods + pre_periods=[0, 1], # Only use first two periods ) assert len(result) > 0 @@ -715,7 +692,7 @@ def test_weight_parameters(self): treatment_column="treated", covariates=["x1"], outcome_weight=1.0, - covariate_weight=0.0 + covariate_weight=0.0, ) # All weight on covariates @@ -727,7 +704,7 @@ def test_weight_parameters(self): treatment_column="treated", covariates=["x1"], outcome_weight=0.0, - covariate_weight=1.0 + covariate_weight=1.0, ) # Rankings should differ @@ -745,10 +722,7 @@ def test_unbalanced_panel(self): # Remove all pre-period data for one control unit control_units = data[data["treated"] == 0]["unit"].unique() unit_to_partially_remove = control_units[0] - mask = ~( - (data["unit"] == unit_to_partially_remove) & - (data["period"] < 3) - ) + mask = ~((data["unit"] == unit_to_partially_remove) & (data["period"] < 3)) unbalanced_data = data[mask].copy() result = rank_control_units( @@ -756,7 +730,7 @@ def test_unbalanced_panel(self): unit_column="unit", time_column="period", outcome_column="outcome", - treatment_column="treated" + treatment_column="treated", ) # Should still work and exclude the unit with no pre-treatment data @@ -776,8 +750,7 @@ def test_single_control_unit(self): single_control = control_units[0] filtered_data = data[ - (data["unit"].isin(treated_units)) | - (data["unit"] == single_control) + (data["unit"].isin(treated_units)) | (data["unit"] == single_control) ].copy() result = rank_control_units( @@ -785,7 +758,7 @@ def test_single_control_unit(self): unit_column="unit", time_column="period", outcome_column="outcome", - treatment_column="treated" + treatment_column="treated", ) assert len(result) == 1 @@ -804,7 +777,13 @@ def test_basic_generation(self): data = generate_staggered_data(n_units=50, n_periods=8, seed=42) assert len(data) == 400 # 50 units x 8 periods assert set(data.columns) == { - "unit", "period", "outcome", "first_treat", "treated", "treat", "true_effect" + "unit", + "period", + "outcome", + "first_treat", + "treated", + "treat", + "true_effect", } def test_never_treated_fraction(self): @@ -819,9 +798,7 @@ def test_cohort_periods(self): """Test custom cohort periods.""" from diff_diff.prep import generate_staggered_data - data = generate_staggered_data( - n_units=100, n_periods=10, cohort_periods=[4, 6], seed=42 - ) + data = generate_staggered_data(n_units=100, n_periods=10, cohort_periods=[4, 6], seed=42) cohorts = data.groupby("unit")["first_treat"].first().unique() assert set(cohorts) == {0, 4, 6} @@ -829,9 +806,7 @@ def test_treatment_effect_direction(self): """Test that treatment effect is positive.""" from diff_diff.prep import generate_staggered_data - data = generate_staggered_data( - n_units=100, treatment_effect=3.0, noise_sd=0.1, seed=42 - ) + data = generate_staggered_data(n_units=100, treatment_effect=3.0, noise_sd=0.1, seed=42) # Treated observations should have positive true_effect treated_effects = data[data["treated"] == 1]["true_effect"] assert (treated_effects > 0).all() @@ -841,8 +816,12 @@ def test_dynamic_effects(self): from diff_diff.prep import generate_staggered_data data = generate_staggered_data( - n_units=50, n_periods=10, treatment_effect=2.0, - dynamic_effects=True, effect_growth=0.1, seed=42 + n_units=50, + n_periods=10, + treatment_effect=2.0, + dynamic_effects=True, + effect_growth=0.1, + seed=42, ) # Effects should grow over time since treatment # Check a treated unit @@ -881,9 +860,7 @@ def test_basic_generation(self): data = generate_factor_data(n_units=30, n_pre=8, n_post=4, n_treated=5, seed=42) assert len(data) == 360 # 30 units x 12 periods - assert set(data.columns) == { - "unit", "period", "outcome", "treated", "treat", "true_effect" - } + assert set(data.columns) == {"unit", "period", "outcome", "treated", "treat", "true_effect"} def test_treated_units_count(self): """Test that n_treated is respected.""" @@ -908,9 +885,14 @@ def test_treatment_effect_recovery(self): true_effect = 3.0 data = generate_factor_data( - n_units=100, n_pre=10, n_post=5, n_treated=30, - treatment_effect=true_effect, noise_sd=0.1, factor_strength=0.1, - seed=42 + n_units=100, + n_pre=10, + n_post=5, + n_treated=30, + treatment_effect=true_effect, + noise_sd=0.1, + factor_strength=0.1, + seed=42, ) # Simple DiD on treated vs control, post vs pre treated_post = data[(data["treat"] == 1) & (data["period"] >= 10)]["outcome"].mean() @@ -994,17 +976,35 @@ def test_treatment_effect_recovery(self): from diff_diff.prep import generate_ddd_data true_effect = 3.0 - data = generate_ddd_data(n_per_cell=200, treatment_effect=true_effect, noise_sd=0.5, seed=42) + data = generate_ddd_data( + n_per_cell=200, treatment_effect=true_effect, noise_sd=0.5, seed=42 + ) # Manual DDD calculation - y_111 = data[(data["group"] == 1) & (data["partition"] == 1) & (data["time"] == 1)]["outcome"].mean() - y_110 = data[(data["group"] == 1) & (data["partition"] == 1) & (data["time"] == 0)]["outcome"].mean() - y_101 = data[(data["group"] == 1) & (data["partition"] == 0) & (data["time"] == 1)]["outcome"].mean() - y_100 = data[(data["group"] == 1) & (data["partition"] == 0) & (data["time"] == 0)]["outcome"].mean() - y_011 = data[(data["group"] == 0) & (data["partition"] == 1) & (data["time"] == 1)]["outcome"].mean() - y_010 = data[(data["group"] == 0) & (data["partition"] == 1) & (data["time"] == 0)]["outcome"].mean() - y_001 = data[(data["group"] == 0) & (data["partition"] == 0) & (data["time"] == 1)]["outcome"].mean() - y_000 = data[(data["group"] == 0) & (data["partition"] == 0) & (data["time"] == 0)]["outcome"].mean() + y_111 = data[(data["group"] == 1) & (data["partition"] == 1) & (data["time"] == 1)][ + "outcome" + ].mean() + y_110 = data[(data["group"] == 1) & (data["partition"] == 1) & (data["time"] == 0)][ + "outcome" + ].mean() + y_101 = data[(data["group"] == 1) & (data["partition"] == 0) & (data["time"] == 1)][ + "outcome" + ].mean() + y_100 = data[(data["group"] == 1) & (data["partition"] == 0) & (data["time"] == 0)][ + "outcome" + ].mean() + y_011 = data[(data["group"] == 0) & (data["partition"] == 1) & (data["time"] == 1)][ + "outcome" + ].mean() + y_010 = data[(data["group"] == 0) & (data["partition"] == 1) & (data["time"] == 0)][ + "outcome" + ].mean() + y_001 = data[(data["group"] == 0) & (data["partition"] == 0) & (data["time"] == 1)][ + "outcome" + ].mean() + y_000 = data[(data["group"] == 0) & (data["partition"] == 0) & (data["time"] == 0)][ + "outcome" + ].mean() manual_ddd = (y_111 - y_110) - (y_101 - y_100) - (y_011 - y_010) + (y_001 - y_000) assert abs(manual_ddd - true_effect) < 0.5 @@ -1027,9 +1027,7 @@ def test_basic_generation(self): data = generate_panel_data(n_units=50, n_periods=6, seed=42) assert len(data) == 300 # 50 units x 6 periods - assert set(data.columns) == { - "unit", "period", "treated", "post", "outcome", "true_effect" - } + assert set(data.columns) == {"unit", "period", "treated", "post", "outcome", "true_effect"} def test_treatment_fraction(self): """Test that treatment_fraction is respected.""" @@ -1072,8 +1070,12 @@ def test_non_parallel_trends(self): from diff_diff.prep import generate_panel_data data = generate_panel_data( - n_units=200, n_periods=8, parallel_trends=False, - trend_violation=1.0, noise_sd=0.1, seed=42 + n_units=200, + n_periods=8, + parallel_trends=False, + trend_violation=1.0, + noise_sd=0.1, + seed=42, ) # Calculate pre-treatment trends pre_data = data[data["post"] == 0] @@ -1116,7 +1118,13 @@ def test_basic_generation(self): data = generate_event_study_data(n_units=100, n_pre=5, n_post=5, seed=42) assert len(data) == 1000 # 100 units x 10 periods assert set(data.columns) == { - "unit", "period", "treated", "post", "outcome", "event_time", "true_effect" + "unit", + "period", + "treated", + "post", + "outcome", + "event_time", + "true_effect", } def test_event_time(self): @@ -1143,8 +1151,7 @@ def test_treatment_effect_recovery(self): true_effect = 4.0 data = generate_event_study_data( - n_units=500, n_pre=5, n_post=5, treatment_effect=true_effect, - noise_sd=0.5, seed=42 + n_units=500, n_pre=5, n_post=5, treatment_effect=true_effect, noise_sd=0.5, seed=42 ) # Simple DiD @@ -1182,8 +1189,18 @@ def test_basic_shape_and_columns(self): data = generate_survey_did_data(n_units=100, n_periods=4, cohort_periods=[2, 3], seed=42) assert len(data) == 400 # 100 units x 4 periods - expected = {"unit", "period", "outcome", "first_treat", "treated", - "true_effect", "stratum", "psu", "fpc", "weight"} + expected = { + "unit", + "period", + "outcome", + "first_treat", + "treated", + "true_effect", + "stratum", + "psu", + "fpc", + "weight", + } assert set(data.columns) == expected def test_survey_columns_valid(self): @@ -1237,8 +1254,10 @@ def test_replicate_weights(self): n_strata, psu_per = 3, 4 data = generate_survey_did_data( - n_strata=n_strata, psu_per_stratum=psu_per, - include_replicate_weights=True, seed=42, + n_strata=n_strata, + psu_per_stratum=psu_per, + include_replicate_weights=True, + seed=42, ) n_psu = n_strata * psu_per rep_cols = [c for c in data.columns if c.startswith("rep_")] @@ -1246,7 +1265,7 @@ def test_replicate_weights(self): # Each replicate should zero out one PSU for r in range(n_psu): - assert (data.loc[data[f"rep_{r}"] == 0, "psu"].nunique() == 1) + assert data.loc[data[f"rep_{r}"] == 0, "psu"].nunique() == 1 def test_covariates(self): """Test covariate columns are added when requested.""" @@ -1277,7 +1296,9 @@ def test_treatment_structure(self): from diff_diff.prep import generate_survey_did_data data = generate_survey_did_data( - cohort_periods=[3, 5], never_treated_frac=0.3, seed=42, + cohort_periods=[3, 5], + never_treated_frac=0.3, + seed=42, ) cohorts = set(data.groupby("unit")["first_treat"].first().unique()) assert 0 in cohorts # never-treated @@ -1308,8 +1329,10 @@ def test_jk1_minimum_psu_guard(self): # Configured count: 1 PSU total with pytest.raises(ValueError, match="at least 2 PSUs"): generate_survey_did_data( - n_strata=1, psu_per_stratum=1, - include_replicate_weights=True, seed=42, + n_strata=1, + psu_per_stratum=1, + include_replicate_weights=True, + seed=42, ) def test_jk1_one_populated_psu_guard(self): @@ -1320,9 +1343,13 @@ def test_jk1_one_populated_psu_guard(self): # 2 configured PSUs but only 1 unit -> only 1 populated PSU with pytest.raises(ValueError, match="at least 2 populated PSUs"): generate_survey_did_data( - n_units=1, n_strata=1, psu_per_stratum=2, - cohort_periods=[2], n_periods=4, - include_replicate_weights=True, seed=42, + n_units=1, + n_strata=1, + psu_per_stratum=2, + cohort_periods=[2], + n_periods=4, + include_replicate_weights=True, + seed=42, ) def test_repeated_cross_section(self): @@ -1330,7 +1357,11 @@ def test_repeated_cross_section(self): from diff_diff.prep import generate_survey_did_data data = generate_survey_did_data( - n_units=20, n_periods=4, cohort_periods=[2], panel=False, seed=42, + n_units=20, + n_periods=4, + cohort_periods=[2], + panel=False, + seed=42, ) assert len(data) == 80 assert data["unit"].nunique() == 80 # unique across all periods @@ -1457,11 +1488,18 @@ def test_psu_period_factor_deff_regression(self): warnings.filterwarnings("ignore") df = generate_survey_did_data( - n_units=200, n_periods=8, cohort_periods=[3, 5], - never_treated_frac=0.3, treatment_effect=2.0, - n_strata=5, psu_per_stratum=8, fpc_per_stratum=200.0, - weight_variation="moderate", psu_re_sd=2.0, - psu_period_factor=1.0, seed=42, + n_units=200, + n_periods=8, + cohort_periods=[3, 5], + never_treated_frac=0.3, + treatment_effect=2.0, + n_strata=5, + psu_per_stratum=8, + fpc_per_stratum=200.0, + weight_variation="moderate", + psu_re_sd=2.0, + psu_period_factor=1.0, + seed=42, ) sd = SurveyDesign(weights="weight", strata="stratum", psu="psu", fpc="fpc") @@ -1472,21 +1510,22 @@ def test_psu_period_factor_deff_regression(self): did = DifferenceInDifferences() r_naive = did.fit(c3, outcome="outcome", treatment="treat", time="post") r_survey = did.fit( - c3, outcome="outcome", treatment="treat", time="post", + c3, + outcome="outcome", + treatment="treat", + time="post", survey_design=sd, ) - assert r_survey.se > r_naive.se, ( - f"Survey SE ({r_survey.se:.4f}) should exceed naive SE ({r_naive.se:.4f})" - ) + assert ( + r_survey.se > r_naive.se + ), f"Survey SE ({r_survey.se:.4f}) should exceed naive SE ({r_naive.se:.4f})" # DEFF for treat_x_post must be > 1 c3["treat_x_post"] = c3["treat"] * c3["post"] resolved = sd.resolve(c3) reg = LinearRegression(include_intercept=True, survey_design=resolved) reg.fit(X=c3[["treat", "post", "treat_x_post"]].values, y=c3["outcome"].values) - deff = reg.compute_deff( - coefficient_names=["intercept", "treat", "post", "treat_x_post"] - ) + deff = reg.compute_deff(coefficient_names=["intercept", "treat", "post", "treat_x_post"]) txp_deff = deff.deff[3] # treat_x_post assert txp_deff > 1.0, f"DEFF for treat_x_post ({txp_deff:.2f}) should be > 1" @@ -1514,9 +1553,7 @@ def test_icc_parameter(self): from diff_diff.prep_dgp import generate_survey_did_data target_icc = 0.3 - df = generate_survey_did_data( - n_units=1000, icc=target_icc, seed=42 - ) + df = generate_survey_did_data(n_units=1000, icc=target_icc, seed=42) # ANOVA-based ICC on period 1 (pre-treatment, no TE contamination) p1 = df[df["period"] == 1] groups = p1.groupby("psu")["outcome"] @@ -1561,9 +1598,7 @@ def test_weight_cv_parameter(self): from diff_diff.prep_dgp import generate_survey_did_data target_cv = 0.5 - df = generate_survey_did_data( - n_units=1000, weight_cv=target_cv, seed=42 - ) + df = generate_survey_did_data(n_units=1000, weight_cv=target_cv, seed=42) weights = df.groupby("unit")["weight"].first().values realized_cv = weights.std() / weights.mean() assert abs(realized_cv - target_cv) < 0.15 @@ -1573,9 +1608,7 @@ def test_weight_cv_and_weight_variation_conflict(self): from diff_diff.prep_dgp import generate_survey_did_data with pytest.raises(ValueError, match="Cannot specify both weight_cv"): - generate_survey_did_data( - weight_cv=0.5, weight_variation="high", seed=42 - ) + generate_survey_did_data(weight_cv=0.5, weight_variation="high", seed=42) def test_weight_cv_nan_inf(self): """weight_cv must reject NaN and Inf.""" @@ -1622,8 +1655,7 @@ def test_informative_sampling_default_weights(self): expected_mean = 1.0 + 1.0 * (s / 4) stratum_weights = p1.loc[p1["stratum"] == s, "weight"] assert abs(stratum_weights.mean() - expected_mean) < 0.15, ( - f"Stratum {s}: expected mean ~{expected_mean}, " - f"got {stratum_weights.mean():.3f}" + f"Stratum {s}: expected mean ~{expected_mean}, " f"got {stratum_weights.mean():.3f}" ) # Within-stratum variation should exist (informative sampling) assert stratum_weights.std() > 0.01 @@ -1670,9 +1702,7 @@ def test_icc_with_covariates(self): from diff_diff.prep_dgp import generate_survey_did_data target_icc = 0.3 - df = generate_survey_did_data( - n_units=1000, icc=target_icc, add_covariates=True, seed=42 - ) + df = generate_survey_did_data(n_units=1000, icc=target_icc, add_covariates=True, seed=42) # ANOVA-based ICC on period 1 p1 = df[df["period"] == 1] groups = p1.groupby("psu")["outcome"] @@ -1840,9 +1870,7 @@ def test_strata_sizes(self): from diff_diff.prep_dgp import generate_survey_did_data sizes = [60, 50, 40, 30, 20] - df = generate_survey_did_data( - n_units=200, strata_sizes=sizes, seed=42 - ) + df = generate_survey_did_data(n_units=200, strata_sizes=sizes, seed=42) for s, expected in enumerate(sizes): actual = df[df["period"] == 1]["stratum"].value_counts().get(s, 0) assert actual == expected @@ -1852,9 +1880,7 @@ def test_strata_sizes_sum_mismatch(self): from diff_diff.prep_dgp import generate_survey_did_data with pytest.raises(ValueError, match="strata_sizes must sum"): - generate_survey_did_data( - n_units=200, strata_sizes=[50, 50, 50, 50, 49], seed=42 - ) + generate_survey_did_data(n_units=200, strata_sizes=[50, 50, 50, 50, 49], seed=42) def test_strata_sizes_float_rejected(self): """strata_sizes must contain integers, not floats.""" @@ -1877,12 +1903,12 @@ def test_covariate_effects_custom(self): """Custom covariate coefficients should change outcome variance.""" from diff_diff.prep_dgp import generate_survey_did_data - df_default = generate_survey_did_data( - n_units=500, add_covariates=True, seed=42 - ) + df_default = generate_survey_did_data(n_units=500, add_covariates=True, seed=42) df_large = generate_survey_did_data( - n_units=500, add_covariates=True, - covariate_effects=(2.0, 1.0), seed=42, + n_units=500, + add_covariates=True, + covariate_effects=(2.0, 1.0), + seed=42, ) # Larger coefficients → larger outcome variance assert df_large["outcome"].var() > df_default["outcome"].var() @@ -1891,12 +1917,12 @@ def test_covariate_effects_zero(self): """Zero covariate effects should produce same variance as no covariates.""" from diff_diff.prep_dgp import generate_survey_did_data - df_no_cov = generate_survey_did_data( - n_units=500, add_covariates=False, seed=42 - ) + df_no_cov = generate_survey_did_data(n_units=500, add_covariates=False, seed=42) df_zero = generate_survey_did_data( - n_units=500, add_covariates=True, - covariate_effects=(0.0, 0.0), seed=42, + n_units=500, + add_covariates=True, + covariate_effects=(0.0, 0.0), + seed=42, ) # Outcome variance should be similar (covariates contribute nothing) assert abs(df_zero["outcome"].var() - df_no_cov["outcome"].var()) < 0.5 @@ -1920,32 +1946,952 @@ def test_te_covariate_interaction_requires_covariates(self): from diff_diff.prep_dgp import generate_survey_did_data with pytest.raises(ValueError, match="te_covariate_interaction requires"): - generate_survey_did_data( - te_covariate_interaction=0.5, add_covariates=False, seed=42 - ) + generate_survey_did_data(te_covariate_interaction=0.5, add_covariates=False, seed=42) def test_covariate_effects_validation(self): """covariate_effects must be length 2 and finite.""" from diff_diff.prep_dgp import generate_survey_did_data with pytest.raises(ValueError, match="covariate_effects must have length 2"): - generate_survey_did_data( - add_covariates=True, covariate_effects=(1.0,), seed=42 - ) + generate_survey_did_data(add_covariates=True, covariate_effects=(1.0,), seed=42) with pytest.raises(ValueError, match="covariate_effects must be finite"): - generate_survey_did_data( - add_covariates=True, covariate_effects=(np.nan, 0.3), seed=42 - ) + generate_survey_did_data(add_covariates=True, covariate_effects=(np.nan, 0.3), seed=42) with pytest.raises(ValueError, match="covariate_effects must be finite"): - generate_survey_did_data( - add_covariates=True, covariate_effects=(0.5, np.inf), seed=42 - ) + generate_survey_did_data(add_covariates=True, covariate_effects=(0.5, np.inf), seed=42) def test_te_covariate_interaction_validation(self): """te_covariate_interaction must be finite.""" from diff_diff.prep_dgp import generate_survey_did_data with pytest.raises(ValueError, match="te_covariate_interaction must be finite"): - generate_survey_did_data( - add_covariates=True, te_covariate_interaction=np.nan, seed=42 + generate_survey_did_data(add_covariates=True, te_covariate_interaction=np.nan, seed=42) + + +class TestAggregateSurvey: + """Tests for aggregate_survey function.""" + + @pytest.fixture + def micro_data(self): + """Create simple microdata: 2 states, 2 years, with survey design.""" + rng = np.random.RandomState(42) + n = 400 + states = np.repeat(["CA", "TX"], n // 2) + years = np.tile(np.repeat([2019, 2020], n // 4), 2) + strata = np.repeat(np.arange(4), n // 4) + psu = np.arange(n) // 10 # 10 obs per PSU, 40 PSUs total + weights = rng.uniform(0.5, 3.0, n) + outcome = rng.normal(10, 2, n) + # Make CA slightly higher than TX to get different means + outcome[: n // 2] += 2.0 + covariate = rng.normal(50, 10, n) + + return pd.DataFrame( + { + "state": states, + "year": years, + "stratum": strata, + "cluster": psu, + "wt": weights, + "y": outcome, + "x": covariate, + } + ) + + @pytest.fixture + def design(self): + return SurveyDesign(weights="wt", strata="stratum", psu="cluster") + + def test_basic_aggregation(self, micro_data, design): + """Design-weighted means should differ from simple means.""" + panel, _ = aggregate_survey( + micro_data, + by=["state", "year"], + outcomes="y", + survey_design=design, + ) + assert len(panel) == 4 # 2 states x 2 years + + # Check design-weighted mean differs from simple mean + simple_mean = micro_data[micro_data["state"] == "CA"][micro_data["year"] == 2019][ + "y" + ].mean() + ca_2019 = panel[(panel["state"] == "CA") & (panel["year"] == 2019)] + weighted_mean = ca_2019["y_mean"].iloc[0] + # With non-uniform weights, these should differ + assert weighted_mean != pytest.approx(simple_mean, abs=0.01) + + def test_column_naming(self, micro_data, design): + """All expected columns should be present.""" + panel, _ = aggregate_survey( + micro_data, + by=["state", "year"], + outcomes="y", + covariates="x", + survey_design=design, + ) + expected = { + "state", + "year", + "y_mean", + "y_se", + "y_n", + "y_precision", + "x_mean", + "cell_n", + "cell_n_eff", + "srs_fallback", + } + assert expected.issubset(set(panel.columns)) + + def test_multiple_outcomes(self, micro_data, design): + """Each outcome gets own columns; SurveyDesign uses first.""" + micro_data = micro_data.copy() + micro_data["y2"] = micro_data["y"] * 2 + panel, stage2 = aggregate_survey( + micro_data, + by=["state", "year"], + outcomes=["y", "y2"], + survey_design=design, + ) + assert "y_mean" in panel.columns + assert "y2_mean" in panel.columns + assert "y_precision" in panel.columns + assert "y2_precision" in panel.columns + assert stage2.weights == "y_weight" + + def test_multi_outcome_filtering_contract(self): + """Multi-outcome filtering is based on first outcome; warns about secondary data loss.""" + rng = np.random.RandomState(33) + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B"], 20), + "time": np.ones(40, dtype=int), + "wt": np.ones(40), + "y1": np.concatenate( + [[np.nan] * 20, rng.normal(10, 2, 20)] + ), # A: all-NaN, B: valid + "y2": rng.normal(5, 1, 40), # valid everywhere + } + ) + design = SurveyDesign(weights="wt") + with pytest.warns(UserWarning, match="y2.*valid data in dropped"): + panel, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes=["y1", "y2"], + survey_design=design, + ) + # Cell A dropped (y1 non-estimable), even though y2 was valid + assert len(panel) == 1 + assert panel["geo"].iloc[0] == "B" + + def test_covariates_mean_only(self, micro_data, design): + """Covariates get mean column only, no SE/precision.""" + panel, _ = aggregate_survey( + micro_data, + by=["state", "year"], + outcomes="y", + covariates="x", + survey_design=design, + ) + assert "x_mean" in panel.columns + assert "x_se" not in panel.columns + assert "x_precision" not in panel.columns + + def test_returned_survey_design(self, micro_data, design): + """Returned SurveyDesign has correct aweight config and clustering.""" + _, stage2 = aggregate_survey( + micro_data, + by=["state", "year"], + outcomes="y", + survey_design=design, + ) + assert stage2.weight_type == "aweight" + assert stage2.weights == "y_weight" + assert stage2.psu == "state" + + def test_srs_fallback(self): + """Cells where design-based variance fails get SRS fallback.""" + # Create data where each cell has only 1 PSU per stratum + rng = np.random.RandomState(99) + n = 40 + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B"], n // 2), + "time": np.tile(np.repeat([0, 1], n // 4), 2), + "stratum": np.arange(n), # every obs is its own stratum + "psu": np.arange(n), # every obs is its own PSU + "wt": np.ones(n), + "y": rng.normal(0, 1, n), + } + ) + design = SurveyDesign(weights="wt", strata="stratum", psu="psu", lonely_psu="remove") + with pytest.warns(UserWarning, match="SRS fallback"): + panel, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design, + ) + assert panel["srs_fallback"].all() + # SRS SE should be finite and positive + assert (panel["y_se"] > 0).all() + assert panel["y_se"].notna().all() + + # Verify SRS SE matches manual computation for one cell + cell = data[(data["geo"] == "A") & (data["time"] == 0)] + y_vals = cell["y"].values + n_cell = len(y_vals) + expected_var = np.var(y_vals, ddof=0) / n_cell * n_cell / (n_cell - 1) + expected_se = np.sqrt(expected_var) + actual_se = panel[(panel["geo"] == "A") & (panel["time"] == 0)]["y_se"].iloc[0] + assert actual_se == pytest.approx(expected_se, rel=1e-10) + + def test_missing_values(self, micro_data, design): + """Missing values reduce var_n but cell_n stays the same.""" + micro_data = micro_data.copy() + # Set some values to NaN in CA-2019 + mask = (micro_data["state"] == "CA") & (micro_data["year"] == 2019) + idx = micro_data[mask].index[:5] + micro_data.loc[idx, "y"] = np.nan + + panel, _ = aggregate_survey( + micro_data, + by=["state", "year"], + outcomes="y", + survey_design=design, + ) + ca_2019 = panel[(panel["state"] == "CA") & (panel["year"] == 2019)] + assert ca_2019["y_n"].iloc[0] == 100 - 5 # 5 NaN + assert ca_2019["cell_n"].iloc[0] == 100 # all respondents + + def test_zero_variance_cell(self): + """When all values are identical, precision is NaN.""" + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B"], 20), + "time": np.tile(np.repeat([0, 1], 10), 2), + "wt": np.ones(40), + "y": np.concatenate( + [ + np.full(10, 5.0), # A-0: constant + np.random.RandomState(1).normal(5, 1, 10), # A-1 + np.random.RandomState(2).normal(5, 1, 10), # B-0 + np.random.RandomState(3).normal(5, 1, 10), # B-1 + ] + ), + } + ) + design = SurveyDesign(weights="wt") + with pytest.warns(UserWarning, match="Zero variance"): + panel, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design, + ) + a0 = panel[(panel["geo"] == "A") & (panel["time"] == 0)] + assert a0["y_mean"].iloc[0] == pytest.approx(5.0) + assert a0["y_se"].iloc[0] == pytest.approx(0.0) + assert np.isnan(a0["y_precision"].iloc[0]) + + def test_lonely_psu_override(self): + """lonely_psu parameter overrides survey_design setting.""" + rng = np.random.RandomState(77) + n = 40 + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B"], n // 2), + "time": np.tile(np.repeat([0, 1], n // 4), 2), + "stratum": np.repeat(np.arange(4), n // 4), + "psu": np.arange(n) // 5, + "wt": np.ones(n), + "y": rng.normal(0, 1, n), + } + ) + design = SurveyDesign(weights="wt", strata="stratum", psu="psu", lonely_psu="remove") + # Override to "certainty" — different behavior for singletons + panel_cert, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design, + lonely_psu="certainty", + ) + panel_remove, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design, + ) + # Results should exist for both + assert len(panel_cert) == 4 + assert len(panel_remove) == 4 + + def test_single_by_column(self, micro_data, design): + """Single string for by works correctly.""" + panel, stage2 = aggregate_survey( + micro_data, + by="state", + outcomes="y", + survey_design=design, + ) + assert len(panel) == 2 # CA, TX + assert "state" in panel.columns + assert stage2.psu == "state" + + def test_srs_equivalence_weights_only(self): + """With no strata/PSU, SE matches weighted SRS formula.""" + rng = np.random.RandomState(123) + n = 100 + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B"], n // 2), + "time": np.ones(n, dtype=int), + "wt": rng.uniform(0.5, 2.0, n), + "y": rng.normal(10, 2, n), + } + ) + design = SurveyDesign(weights="wt") + panel, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design, + ) + + # Manual full-design domain estimation for cell A: + # psi is zero-padded to n_total; adjustment uses n_total/(n_total-1) + w_all = data["wt"].values + w_all_norm = w_all / w_all.mean() # resolve() normalizes to mean=1 + cell_mask = (data["geo"] == "A").values + y_all = data["y"].values + w_cell = w_all_norm[cell_mask] + y_cell = y_all[cell_mask] + sum_w = np.sum(w_cell) + y_bar = np.sum(w_cell * y_cell) / sum_w + + # Full-length psi with zeros outside cell + psi_full = np.zeros(n) + psi_full[cell_mask] = w_cell * (y_cell - y_bar) / sum_w + + # Implicit per-obs PSU with full-design adjustment + psi_mean = psi_full.mean() + centered = psi_full - psi_mean + variance = (n / (n - 1)) * np.sum(centered**2) + expected_se = np.sqrt(variance) + + actual_se = panel[panel["geo"] == "A"]["y_se"].iloc[0] + assert actual_se == pytest.approx(expected_se, rel=1e-10) + + def test_design_effect_increases_se(self): + """With PSU clustering, SE should be larger than without.""" + rng = np.random.RandomState(55) + n = 200 + psu_ids = np.arange(n) // 10 # 10 obs per PSU + # Add PSU-level random effects to create ICC + psu_effects = rng.normal(0, 3, 20) + y = rng.normal(0, 1, n) + psu_effects[psu_ids] + + data = pd.DataFrame( + { + "geo": ["A"] * n, + "time": np.ones(n, dtype=int), + "cluster": psu_ids, + "wt": np.ones(n), + "y": y, + } + ) + + design_no_psu = SurveyDesign(weights="wt") + design_psu = SurveyDesign(weights="wt", psu="cluster") + + panel_no_psu, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design_no_psu, + ) + panel_psu, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design_psu, + ) + + se_no_psu = panel_no_psu["y_se"].iloc[0] + se_psu = panel_psu["y_se"].iloc[0] + assert se_psu > se_no_psu # clustering increases SE + + def test_equal_weights_simple_mean(self): + """With equal weights, design-weighted mean equals arithmetic mean.""" + rng = np.random.RandomState(88) + n = 60 + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B"], n // 2), + "time": np.tile(np.repeat([0, 1], n // 4), 2), + "wt": np.ones(n), + "y": rng.normal(10, 2, n), + } + ) + design = SurveyDesign(weights="wt") + panel, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design, + ) + # Check A-0 cell + cell = data[(data["geo"] == "A") & (data["time"] == 0)] + assert panel[(panel["geo"] == "A") & (panel["time"] == 0)]["y_mean"].iloc[0] == ( + pytest.approx(cell["y"].mean(), rel=1e-12) + ) + + def test_pipeline_with_did(self): + """Full pipeline: microdata → aggregate → DiD estimation.""" + from diff_diff import DifferenceInDifferences + + # Construct microdata simulating 4 states, 2 periods, ~50 obs/cell + rng = np.random.RandomState(42) + rows = [] + for state in range(4): + treated = 1 if state < 2 else 0 + for period in [0, 1]: + n_cell = rng.randint(40, 60) + # Treatment effect in post period for treated states + te = 3.0 if (treated and period == 1) else 0.0 + for _ in range(n_cell): + strat = rng.randint(0, 3) + psu_id = state * 100 + strat * 10 + rng.randint(0, 3) + rows.append( + { + "state": state, + "period": period, + "stratum": strat, + "psu": psu_id, + "wt": rng.uniform(0.5, 2.0), + "outcome": rng.normal(10 + te, 2), + "treated": treated, + } + ) + micro = pd.DataFrame(rows) + + design = SurveyDesign(weights="wt", strata="stratum", psu="psu") + panel, stage2 = aggregate_survey( + micro, + by=["state", "period"], + outcomes="outcome", + covariates="treated", + survey_design=design, + ) + + panel["treated_bin"] = (panel["treated_mean"] > 0.5).astype(int) + + did = DifferenceInDifferences() + result = did.fit( + panel, + outcome="outcome_mean", + treatment="treated_bin", + time="period", + survey_design=stage2, + ) + assert result.att is not None + assert np.isfinite(result.att) + assert np.isfinite(result.se) + assert result.se > 0 + # ATT should be near the true effect of 3.0 + assert abs(result.att - 3.0) < 2.0 + + # --- Error tests --- + + def test_error_missing_column(self, micro_data, design): + """Missing column raises ValueError.""" + with pytest.raises(ValueError, match="Columns not found"): + aggregate_survey( + micro_data, + by=["state", "year"], + outcomes="nonexistent", + survey_design=design, + ) + + def test_error_invalid_survey_design(self, micro_data): + """Non-SurveyDesign object raises TypeError.""" + with pytest.raises(TypeError, match="SurveyDesign instance"): + aggregate_survey( + micro_data, + by=["state", "year"], + outcomes="y", + survey_design="not_a_design", ) + + def test_error_min_n_too_small(self, micro_data, design): + """min_n < 1 raises ValueError.""" + with pytest.raises(ValueError, match="min_n must be >= 1"): + aggregate_survey( + micro_data, + by=["state", "year"], + outcomes="y", + survey_design=design, + min_n=0, + ) + + def test_error_empty_by(self, micro_data, design): + """Empty by list raises ValueError.""" + with pytest.raises(ValueError, match="at least one grouping column"): + aggregate_survey(micro_data, by=[], outcomes="y", survey_design=design) + + def test_error_empty_outcomes(self, micro_data, design): + """Empty outcomes list raises ValueError.""" + with pytest.raises(ValueError, match="at least one outcome"): + aggregate_survey( + micro_data, + by=["state", "year"], + outcomes=[], + survey_design=design, + ) + + def test_error_non_numeric_outcome(self, micro_data, design): + """Non-numeric outcome column raises ValueError.""" + data = micro_data.copy() + data["label"] = "foo" + with pytest.raises(ValueError, match="Non-numeric column"): + aggregate_survey( + data, + by=["state", "year"], + outcomes="label", + survey_design=design, + ) + + def test_nullable_numeric_dtypes(self): + """Pandas nullable Int64/Float64 dtypes are accepted as numeric.""" + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B"], 10), + "time": np.ones(20, dtype=int), + "wt": np.ones(20), + "y": pd.array(np.random.RandomState(1).normal(0, 1, 20), dtype="Float64"), + } + ) + design = SurveyDesign(weights="wt") + panel, _ = aggregate_survey(data, by=["geo", "time"], outcomes="y", survey_design=design) + assert len(panel) == 2 + assert panel["y_mean"].notna().all() + + def test_error_empty_data(self, design): + """Empty DataFrame raises ValueError.""" + empty = pd.DataFrame(columns=["state", "year", "y", "wt", "stratum", "cluster"]) + with pytest.raises(ValueError, match="data must be non-empty"): + aggregate_survey( + empty, + by=["state", "year"], + outcomes="y", + survey_design=design, + ) + + def test_error_missing_grouping_keys(self, micro_data, design): + """NaN in grouping columns raises ValueError.""" + data = micro_data.copy() + data.loc[0, "state"] = np.nan + with pytest.raises(ValueError, match="Missing values in grouping column"): + aggregate_survey( + data, + by=["state", "year"], + outcomes="y", + survey_design=design, + ) + + def test_error_all_missing_grouping_keys(self, design): + """All-NaN grouping column raises ValueError.""" + data = pd.DataFrame( + { + "state": [np.nan] * 10, + "year": np.ones(10, dtype=int), + "y": np.random.RandomState(1).normal(0, 1, 10), + "wt": np.ones(10), + } + ) + design_simple = SurveyDesign(weights="wt") + with pytest.raises(ValueError, match="Missing values in grouping column"): + aggregate_survey( + data, + by=["state", "year"], + outcomes="y", + survey_design=design_simple, + ) + + def test_stage2_handoff_with_nonfinite_cells(self): + """Non-estimable cells are dropped; stage2 works with fit().""" + from diff_diff import DifferenceInDifferences + + rng = np.random.RandomState(99) + rows = [] + for state in range(4): + treated = 1 if state < 2 else 0 + for period in [0, 1]: + te = 3.0 if (treated and period == 1) else 0.0 + n_cell = 30 + for _ in range(n_cell): + rows.append( + { + "state": state, + "period": period, + "wt": rng.uniform(0.5, 2.0), + "outcome": rng.normal(10 + te, 2), + "treated": treated, + } + ) + micro = pd.DataFrame(rows) + # Make one cell all-NaN outcome → n_valid=0 → NaN mean → dropped + mask = (micro["state"] == 0) & (micro["period"] == 0) + micro.loc[mask, "outcome"] = np.nan + + design = SurveyDesign(weights="wt") + with pytest.warns(UserWarning, match="non-estimable"): + panel, stage2 = aggregate_survey( + micro, + by=["state", "period"], + outcomes="outcome", + covariates="treated", + survey_design=design, + ) + + # Non-estimable cell should be dropped from panel + assert len(panel) == 7 # 8 cells - 1 dropped + assert not ((panel["state"] == 0) & (panel["period"] == 0)).any() + + # No NaN in outcome mean or weight columns + assert panel["outcome_mean"].notna().all() + assert panel["outcome_weight"].notna().all() + + # stage2 should work with fit() + panel["treated_bin"] = (panel["treated_mean"] > 0.5).astype(int) + did = DifferenceInDifferences() + result = did.fit( + panel, + outcome="outcome_mean", + treatment="treated_bin", + time="period", + survey_design=stage2, + ) + assert np.isfinite(result.att) + assert np.isfinite(result.se) + + def test_zero_weight_psu_dropped(self): + """Geographic units with zero total weight are dropped from panel.""" + rng = np.random.RandomState(88) + # State 0: all cells have only 1 valid obs → NaN precision → weight=0 + # State 1-3: normal cells + rows = [] + for state in range(4): + for period in [0, 1]: + if state == 0: + # 1 obs per cell → NaN SE → weight=0 + rows.append( + { + "state": state, + "period": period, + "wt": 1.0, + "y": rng.normal(10, 2), + } + ) + else: + for _ in range(20): + rows.append( + { + "state": state, + "period": period, + "wt": 1.0, + "y": rng.normal(10, 2), + } + ) + data = pd.DataFrame(rows) + design = SurveyDesign(weights="wt") + with pytest.warns(UserWarning, match="zero total weight"): + panel, _ = aggregate_survey( + data, by=["state", "period"], outcomes="y", survey_design=design + ) + # State 0 should be entirely gone + assert 0 not in panel["state"].values + assert len(panel) == 6 # 3 states × 2 periods + + def test_error_all_cells_dropped(self): + """All cells non-estimable raises ValueError.""" + data = pd.DataFrame( + { + "state": ["A"] * 5, + "period": np.ones(5, dtype=int), + "wt": np.ones(5), + "y": [np.nan] * 5, + } + ) + design = SurveyDesign(weights="wt") + with pytest.raises(ValueError, match="No estimable cells remain"): + aggregate_survey(data, by=["state", "period"], outcomes="y", survey_design=design) + + def test_zero_weight_rows_excluded_from_n_valid(self): + """Zero-weight rows should not count as valid observations.""" + rng = np.random.RandomState(66) + # Cell A: 3 positive-weight obs + 7 zero-weight padding + # n_valid should be 3, not 10 + data = pd.DataFrame( + { + "geo": ["A"] * 10 + ["B"] * 10, + "time": np.ones(20, dtype=int), + "wt": np.concatenate( + [ + np.array([1.0, 1.0, 1.0] + [0.0] * 7), # A: 3 real + np.ones(10), # B: all real + ] + ), + "y": rng.normal(10, 2, 20), + } + ) + design = SurveyDesign(weights="wt") + panel, _ = aggregate_survey(data, by=["geo", "time"], outcomes="y", survey_design=design) + cell_a = panel[panel["geo"] == "A"] + # Only 3 positive-weight obs → n_valid=3 + assert cell_a["y_n"].iloc[0] == 3 + + cell_b = panel[panel["geo"] == "B"] + # 10 positive-weight obs + assert cell_b["y_n"].iloc[0] == 10 + assert cell_b["y_se"].iloc[0] > 0 + + def test_duplicate_index(self): + """Duplicate DataFrame indices do not break aggregation.""" + rng = np.random.RandomState(77) + n = 40 + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B"], n // 2), + "time": np.tile(np.repeat([0, 1], n // 4), 2), + "wt": np.ones(n), + "y": rng.normal(10, 2, n), + } + ) + # Create duplicate indices (e.g., from concat without reset_index) + data.index = list(range(n // 2)) * 2 # 0..19, 0..19 + + design = SurveyDesign(weights="wt") + panel_dup, _ = aggregate_survey( + data, by=["geo", "time"], outcomes="y", survey_design=design + ) + + # Compare against clean-index version + data_clean = data.reset_index(drop=True) + panel_clean, _ = aggregate_survey( + data_clean, by=["geo", "time"], outcomes="y", survey_design=design + ) + + # Results should be identical + np.testing.assert_allclose( + panel_dup["y_mean"].values, panel_clean["y_mean"].values, rtol=1e-12 + ) + np.testing.assert_allclose(panel_dup["y_se"].values, panel_clean["y_se"].values, rtol=1e-12) + + def test_domain_estimation_preserves_full_design(self): + """Full-design domain estimation accounts for PSUs outside the cell. + + Stratum 0 has PSUs {0, 1}. Cell A contains only PSU 0. + With physical subsetting, stratum 0 would be a singleton → skipped. + With full-design domain estimation, both PSUs participate → non-zero + stratum variance contribution and no SRS fallback. + """ + rng = np.random.RandomState(42) + # 2 strata, 2 PSUs each, 5 obs per PSU = 20 obs total + # Cell A = first 10 obs (stratum 0 PSU 0 + stratum 1 PSU 2) + # Cell B = last 10 obs (stratum 0 PSU 1 + stratum 1 PSU 3) + data = pd.DataFrame( + { + "geo": ["A"] * 10 + ["B"] * 10, + "time": np.ones(20, dtype=int), + "stratum": np.repeat([0, 0, 1, 1], 5), + "psu": np.repeat([0, 1, 2, 3], 5), + "wt": np.ones(20), + "y": rng.normal(10, 2, 20), + } + ) + # Reassign so cell A has only PSU 0 from stratum 0, cell B has only PSU 1 + data.loc[:4, "geo"] = "A" # stratum 0, PSU 0 + data.loc[5:9, "geo"] = "B" # stratum 0, PSU 1 + data.loc[10:14, "geo"] = "A" # stratum 1, PSU 2 + data.loc[15:19, "geo"] = "B" # stratum 1, PSU 3 + + design = SurveyDesign(weights="wt", strata="stratum", psu="psu", lonely_psu="remove") + panel, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design, + ) + + cell_a = panel[panel["geo"] == "A"] + # With full-design domain estimation: + # - Both PSUs in each stratum participate (one with zero psi) + # - No singleton PSU → no SRS fallback needed + assert not cell_a["srs_fallback"].iloc[0] + assert cell_a["y_se"].iloc[0] > 0 + assert np.isfinite(cell_a["y_se"].iloc[0]) + assert np.isfinite(cell_a["y_precision"].iloc[0]) + + def test_min_n_forces_srs_fallback(self): + """min_n parameter forces SRS fallback for small cells.""" + rng = np.random.RandomState(44) + n = 40 + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B"], n // 2), + "time": np.ones(n, dtype=int), + "wt": np.ones(n), + "y": rng.normal(10, 2, n), + } + ) + design = SurveyDesign(weights="wt") + + # min_n=30 → cells with 20 obs each should use SRS fallback + with pytest.warns(UserWarning, match="SRS fallback"): + panel_high, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design, + min_n=30, + ) + assert panel_high["srs_fallback"].all() + + # min_n=1 → should use design-based variance (no fallback) + panel_low, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design, + min_n=1, + ) + assert not panel_low["srs_fallback"].any() + + # SEs should differ between the two + se_high = panel_high[panel_high["geo"] == "A"]["y_se"].iloc[0] + se_low = panel_low[panel_low["geo"] == "A"]["y_se"].iloc[0] + assert se_high != pytest.approx(se_low, rel=1e-6) + + def test_replicate_weight_aggregation(self): + """Aggregation with replicate weight designs produces valid SEs.""" + from diff_diff.prep_dgp import generate_survey_did_data + + micro = generate_survey_did_data( + n_units=200, + n_periods=4, + cohort_periods=[3], + n_strata=3, + psu_per_stratum=6, + include_replicate_weights=True, + panel=False, + seed=42, + ) + # Build replicate weight column list + rep_cols = [c for c in micro.columns if c.startswith("rep_")] + design = SurveyDesign( + weights="weight", + replicate_weights=rep_cols, + replicate_method="JK1", + ) + panel, _ = aggregate_survey( + micro, + by=["stratum", "period"], + outcomes="outcome", + survey_design=design, + ) + # All cells should have finite, positive SEs + assert panel["outcome_se"].notna().all() + assert (panel["outcome_se"] > 0).all() + + def test_replicate_weight_min_n_fallback(self): + """SRS fallback works correctly under replicate-weight designs.""" + from diff_diff.prep_dgp import generate_survey_did_data + + micro = generate_survey_did_data( + n_units=200, + n_periods=4, + cohort_periods=[3], + n_strata=3, + psu_per_stratum=6, + include_replicate_weights=True, + panel=False, + seed=42, + ) + rep_cols = [c for c in micro.columns if c.startswith("rep_")] + design = SurveyDesign( + weights="weight", + replicate_weights=rep_cols, + replicate_method="JK1", + ) + + # min_n high enough to force SRS fallback on all cells + with pytest.warns(UserWarning, match="SRS fallback"): + panel_srs, _ = aggregate_survey( + micro, + by=["stratum", "period"], + outcomes="outcome", + survey_design=design, + min_n=9999, + ) + assert panel_srs["srs_fallback"].all() + assert panel_srs["outcome_se"].notna().all() + assert (panel_srs["outcome_se"] > 0).all() + + # Default min_n → replicate-based variance (no fallback) + panel_rep, _ = aggregate_survey( + micro, + by=["stratum", "period"], + outcomes="outcome", + survey_design=design, + ) + assert not panel_rep["srs_fallback"].any() + + # SEs should differ between SRS fallback and replicate-based + assert not np.allclose( + panel_srs["outcome_se"].values, + panel_rep["outcome_se"].values, + rtol=1e-6, + ) + + def test_srs_fallback_scale_invariant(self): + """SRS fallback SEs are invariant to constant weight rescaling.""" + rng = np.random.RandomState(55) + n = 60 + data = pd.DataFrame( + { + "geo": np.repeat(["A", "B", "C"], n // 3), + "time": np.ones(n, dtype=int), + "wt": rng.uniform(0.5, 2.0, n), + "y": rng.normal(10, 2, n), + } + ) + design1 = SurveyDesign(weights="wt") + + # Force SRS fallback with high min_n + with pytest.warns(UserWarning, match="SRS fallback"): + panel1, _ = aggregate_survey( + data, + by=["geo", "time"], + outcomes="y", + survey_design=design1, + min_n=9999, + ) + + # Rescale weights by 5x → should give identical SEs + data2 = data.copy() + data2["wt"] = data2["wt"] * 5.0 + design2 = SurveyDesign(weights="wt") + with pytest.warns(UserWarning, match="SRS fallback"): + panel2, _ = aggregate_survey( + data2, + by=["geo", "time"], + outcomes="y", + survey_design=design2, + min_n=9999, + ) + + np.testing.assert_allclose(panel1["y_se"].values, panel2["y_se"].values, rtol=1e-10) + np.testing.assert_allclose(panel1["y_mean"].values, panel2["y_mean"].values, rtol=1e-10)