
Commit d330d36

igerber and claude committed
Fix P0 domain estimation, P1 min_n/empty guard, P3 docs/tests
Rework aggregate_survey() to use full-design domain estimation: zero-pad influence functions outside each cell, preserving full strata/PSU structure for variance computation per Lumley (2004) Section 3.4 and the library's subpopulation() methodology.

Also fix: min_n parameter now operative (forces SRS fallback), empty-input guard added, docstring examples corrected, REGISTRY.md rewritten with domain estimation language, 4 new tests added (domain regression, min_n behavior, empty input, replicate weights).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
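The difference between subsetting the design to a cell and treating the cell as a domain of the full design can be illustrated with a small sketch. The helpers `tsl_variance` and `domain_mean_variance` below are hypothetical stand-ins written for this illustration, not the library's `compute_survey_if_variance()`; they assume a simple with-replacement stratified PSU estimator and ignore FPC, lonely-PSU handling, and replicate weights.

```python
import numpy as np


def tsl_variance(psi, strata, psu):
    """With-replacement linearization variance of sum(psi).

    Hypothetical helper: sums psi within each PSU, then adds
    n_h / (n_h - 1) * sum((z_hj - mean_h)^2) for each stratum h.
    Strata with a single PSU contribute nothing in this sketch.
    """
    total = 0.0
    for h in np.unique(strata):
        in_h = strata == h
        psus = np.unique(psu[in_h])
        n_h = len(psus)
        if n_h < 2:
            continue
        z = np.array([psi[in_h & (psu == j)].sum() for j in psus])
        total += n_h / (n_h - 1) * np.sum((z - z.mean()) ** 2)
    return float(total)


def domain_mean_variance(y, w, strata, psu, cell_mask):
    """Weighted cell mean plus its variance under the full design."""
    sum_w = w[cell_mask].sum()
    y_bar = float(np.sum(w[cell_mask] * y[cell_mask]) / sum_w)
    # Zero-padded influence function: non-members keep their rows (and
    # therefore their strata/PSU membership) but contribute psi = 0.
    psi = np.zeros(len(y))
    psi[cell_mask] = w[cell_mask] * (y[cell_mask] - y_bar) / sum_w
    return y_bar, tsl_variance(psi, strata, psu)


# Toy design: 2 strata, 2 PSUs per stratum, 2 observations per PSU.
strata = np.array([0, 0, 0, 0, 1, 1, 1, 1])
psu = np.array([0, 0, 1, 1, 2, 2, 3, 3])
w = np.ones(8)
y = np.arange(1.0, 9.0)
# The cell only has members in one PSU of each stratum.
cell_mask = np.array([True, True, False, False, True, True, False, False])

y_bar, v_domain = domain_mean_variance(y, w, strata, psu, cell_mask)

# Naive alternative: drop non-members first, then compute the variance.
psi_cell = w[cell_mask] * (y[cell_mask] - y_bar) / w[cell_mask].sum()
v_subset = tsl_variance(psi_cell, strata[cell_mask], psu[cell_mask])

print(y_bar, v_domain, v_subset)
```

On this toy data the subsetted design leaves a single PSU in every stratum, so the naive estimate collapses to zero, while the zero-padded version still sees two PSUs per stratum and returns a positive variance.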
1 parent 2994d49 commit d330d36

4 files changed

Lines changed: 238 additions & 86 deletions


diff_diff/prep.py

Lines changed: 80 additions & 60 deletions
@@ -1313,64 +1313,85 @@ def trim_weights(
 
 
 def _cell_mean_variance(
-    y: np.ndarray,
-    weights: np.ndarray,
-    cell_resolved: ResolvedSurveyDesign,
+    y_full: np.ndarray,
+    full_resolved: ResolvedSurveyDesign,
+    cell_mask: np.ndarray,
+    min_n: int,
 ) -> Tuple[float, float, int, bool]:
     """Compute design-based mean and variance of the weighted mean for one cell.
 
+    Uses full-design domain estimation: the influence function is zero-padded
+    outside the cell, preserving the full strata/PSU structure for variance
+    estimation. This is the methodologically correct approach for domain
+    estimation under complex survey designs (Lumley 2004, Section 3.4).
+
     Parameters
     ----------
-    y : np.ndarray
-        Outcome values for the cell (may contain NaN).
-    weights : np.ndarray
-        Resolved weights for the cell (already extracted from ResolvedSurveyDesign).
-    cell_resolved : ResolvedSurveyDesign
-        Resolved survey design subsetted to this cell.
+    y_full : np.ndarray
+        Outcome values for the full dataset (may contain NaN).
+    full_resolved : ResolvedSurveyDesign
+        Full-sample resolved survey design.
+    cell_mask : np.ndarray
+        Boolean mask identifying cell members in the full dataset.
+    min_n : int
+        Minimum valid observations for design-based variance. Below this
+        threshold, SRS fallback is used.
 
     Returns
     -------
     mean : float
         Design-weighted cell mean.
     variance : float
         Design-based variance of the cell mean (>= 0). Uses SRS fallback
-        when the design-based estimate is unidentifiable.
+        when the design-based estimate is unidentifiable or n_valid < min_n.
     n_valid : int
-        Number of non-missing observations.
+        Number of non-missing observations in the cell.
     used_srs_fallback : bool
         True if SRS variance was used instead of design-based.
     """
-    valid = ~np.isnan(y)
+    y_cell = y_full[cell_mask]
+    w_cell = full_resolved.weights[cell_mask]
+    valid = ~np.isnan(y_cell)
     n_valid = int(np.sum(valid))
 
     if n_valid == 0:
         return np.nan, np.nan, 0, False
 
-    if n_valid == 1:
-        y_bar = float(y[valid][0])
+    if n_valid < 2:
+        y_bar = float(y_cell[valid][0])
         return y_bar, np.nan, 1, False
 
-    # Zero out weights for NaN observations (subpopulation approach)
-    w = weights.copy()
-    y_clean = np.where(valid, y, 0.0)
-    w_valid = w * valid.astype(np.float64)
-    sum_w = np.sum(w_valid)
+    # Weighted mean from cell members (NaN-safe)
+    w_valid = w_cell * valid.astype(np.float64)
+    y_clean = np.where(valid, y_cell, 0.0)
+    sum_w = float(np.sum(w_valid))
 
     if sum_w <= 0:
         return np.nan, np.nan, n_valid, False
 
-    # Design-weighted mean
     y_bar = float(np.sum(w_valid * y_clean) / sum_w)
 
-    # Influence function: psi_i = w_i * (y_i - y_bar) / sum(w)
-    psi = w_valid * (y_clean - y_bar) / sum_w
-
-    # Route to TSL or replicate variance
+    # SRS fallback if below min_n threshold
     used_srs = False
-    if cell_resolved.uses_replicate_variance:
-        variance, _ = compute_replicate_if_variance(psi, cell_resolved)
+    if n_valid < min_n:
+        resid_sq = w_valid * (y_clean - y_bar) ** 2
+        variance = float(np.sum(resid_sq) / (sum_w**2) * n_valid / (n_valid - 1))
+        return y_bar, max(variance, 0.0), n_valid, True
+
+    # Full-design domain estimation: construct full-length psi with zeros
+    # outside the cell, preserving full strata/PSU structure for variance
+    n_total = len(y_full)
+    psi = np.zeros(n_total)
+    # Positions in full array where cell member has valid data
+    cell_indices = np.where(cell_mask)[0]
+    valid_positions = cell_indices[valid]
+    psi[valid_positions] = w_valid[valid] * (y_clean[valid] - y_bar) / sum_w
+
+    # Route to TSL or replicate variance using the full design
+    if full_resolved.uses_replicate_variance:
+        variance, _ = compute_replicate_if_variance(psi, full_resolved)
     else:
-        variance = compute_survey_if_variance(psi, cell_resolved)
+        variance = compute_survey_if_variance(psi, full_resolved)
 
     # SRS fallback when design-based variance is unidentifiable
     if np.isnan(variance):
@@ -1397,9 +1418,10 @@ def aggregate_survey(
     columns. Returns a panel-ready DataFrame with precision weights and a
     pre-configured :class:`SurveyDesign` for second-stage DiD estimation.
 
-    This follows R's ``survey::svyby()`` pattern: the survey design is
-    subsetted to each cell and domain-level statistics are computed using
-    the within-cell strata/PSU structure.
+    Each cell is treated as a subpopulation/domain of the full survey
+    design: influence function values are zero-padded outside the cell,
+    preserving full strata/PSU structure for variance estimation per
+    Lumley (2004) Section 3.4.
 
     Parameters
     ----------
@@ -1446,7 +1468,7 @@ def aggregate_survey(
     ... )
     >>> result = DifferenceInDifferences().fit(
     ...     panel, outcome="smoking_rate_mean",
-    ...     treatment="treated", time="post", survey_design=stage2,
+    ...     treatment="treated", time="year", survey_design=stage2,
     ... )
     """
     import warnings
@@ -1482,12 +1504,21 @@ def aggregate_survey(
             f"lonely_psu must be 'remove', 'certainty', or 'adjust', got '{lonely_psu}'"
         )
 
+    # --- Empty-input guard ---
+    if data.empty:
+        raise ValueError("data must be non-empty")
+
     # --- Resolve design once on full data ---
     effective_design = (
         replace(survey_design, lonely_psu=lonely_psu) if lonely_psu else survey_design
     )
     full_resolved = effective_design.resolve(data)
 
+    # --- Precompute full-length outcome/covariate arrays ---
+    n_total = len(data)
+    all_vars = outcome_cols + cov_cols
+    y_arrays: Dict[str, np.ndarray] = {var: data[var].values.astype(np.float64) for var in all_vars}
+
     # --- Per-cell computation ---
     grouped = data.groupby(by_cols, sort=True)
     rows: List[Dict[str, Any]] = []
@@ -1496,32 +1527,17 @@ def aggregate_survey(
 
     for cell_key, cell_df in grouped:
         cell_idx = np.array(cell_df.index)
-        # Convert to positional indices for array subsetting
         pos_idx = data.index.get_indexer(cell_idx)
 
-        cell_n = len(pos_idx)
-        cell_key_str = str(cell_key)
+        # Boolean mask for full-design domain estimation
+        cell_mask = np.zeros(n_total, dtype=bool)
+        cell_mask[pos_idx] = True
 
-        # Subset arrays from full resolved design
-        cell_w = full_resolved.weights[pos_idx]
-        cell_strata = full_resolved.strata[pos_idx] if full_resolved.strata is not None else None
-        cell_psu = full_resolved.psu[pos_idx] if full_resolved.psu is not None else None
-        cell_fpc = full_resolved.fpc[pos_idx] if full_resolved.fpc is not None else None
-
-        cell_n_strata = int(len(np.unique(cell_strata))) if cell_strata is not None else 0
-        cell_n_psu = int(len(np.unique(cell_psu))) if cell_psu is not None else 0
-
-        cell_resolved = full_resolved.subset_to_units(
-            row_idx=pos_idx,
-            weights=cell_w,
-            strata=cell_strata,
-            psu=cell_psu,
-            fpc=cell_fpc,
-            n_strata=cell_n_strata,
-            n_psu=cell_n_psu,
-        )
+        cell_n = int(np.sum(cell_mask))
+        cell_key_str = str(cell_key)
 
-        # Cell-level statistics
+        # Cell-level statistics (Kish ESS is a property of the cell)
+        cell_w = full_resolved.weights[cell_mask]
         sum_w = float(np.sum(cell_w))
         sum_w2 = float(np.sum(cell_w**2))
         cell_n_eff = (sum_w**2 / sum_w2) if sum_w2 > 0 else 0.0
@@ -1539,10 +1555,14 @@ def aggregate_survey(
 
         cell_srs_fallback = False
 
-        # Outcomes: mean + SE + n + precision
+        # Outcomes: mean + SE + n + precision (full-design domain estimation)
         for var in outcome_cols:
-            y = cell_df[var].values.astype(np.float64)
-            y_bar, variance, n_valid, used_srs = _cell_mean_variance(y, cell_w, cell_resolved)
+            y_bar, variance, n_valid, used_srs = _cell_mean_variance(
+                y_arrays[var],
+                full_resolved,
+                cell_mask,
+                min_n,
+            )
             se = float(np.sqrt(variance)) if not np.isnan(variance) else np.nan
 
             if used_srs:
@@ -1562,14 +1582,14 @@ def aggregate_survey(
             row[f"{var}_n"] = n_valid
             row[f"{var}_precision"] = precision
 
-        # Covariates: mean only
+        # Covariates: design-weighted mean only
         for var in cov_cols:
-            y = cell_df[var].values.astype(np.float64)
-            valid = ~np.isnan(y)
+            y_cell = y_arrays[var][cell_mask]
+            valid = ~np.isnan(y_cell)
             w_valid = cell_w * valid.astype(np.float64)
             sw = float(np.sum(w_valid))
             if sw > 0:
-                row[f"{var}_mean"] = float(np.sum(w_valid * np.where(valid, y, 0.0)) / sw)
+                row[f"{var}_mean"] = float(np.sum(w_valid * np.where(valid, y_cell, 0.0)) / sw)
             else:
                 row[f"{var}_mean"] = np.nan

docs/api/prep.rst

Lines changed: 1 addition & 1 deletion
@@ -288,7 +288,7 @@ Example
         panel,
         outcome="smoking_rate_mean",
         treatment="treated",
-        time="post",
+        time="year",
         survey_design=stage2,
     )

docs/methodology/REGISTRY.md

Lines changed: 11 additions & 12 deletions
@@ -2310,22 +2310,21 @@ design-based precision estimates, for use as a pre-processing step before panel
 DiD estimation on repeated cross-section survey data.
 
 - **Reference**: Lumley (2004) "Analysis of Complex Survey Samples", Journal of
-  Statistical Software 9(8). R `survey::svyby()` implements similar per-group
-  survey estimation.
+  Statistical Software 9(8), Section 3.4 (domain estimation).
 - **Cell mean**: Design-weighted mean `ȳ_g = Σ w_i y_i / Σ w_i` for each cell g
   defined by grouping columns (e.g., state × year).
-- **Cell variance**: Linearized influence function `ψ_i = w_i (y_i - ȳ_g) / Σ w_j`,
-  then design-based variance via `compute_survey_if_variance()` (TSL) or
-  `compute_replicate_if_variance()` (replicate designs). This is the standard
-  Horvitz-Thompson linearization for a ratio estimator.
+- **Cell variance**: Each cell is treated as a subpopulation/domain of the full
+  survey design (consistent with `SurveyDesign.subpopulation()` and the
+  Subpopulation Analysis section below). The influence function
+  `ψ_i = w_i (y_i - ȳ_g) / Σ w_j` is zero-padded outside the cell, preserving
+  full strata/PSU structure for variance estimation via `compute_survey_if_variance()`
+  (TSL) or `compute_replicate_if_variance()` (replicate designs).
 - **Precision weight**: `1 / V(ȳ_g)` used as inverse-variance weight (aweight)
   in second-stage DiD estimation.
-- **Note:** SRS fallback when design-based variance is unidentifiable within a cell
-  (e.g., all strata have singleton PSUs after cell subsetting). Formula:
-  `V_SRS = Σ w_i(y_i - ȳ)² / (Σ w_j)² × n/(n-1)`. Cells using SRS fallback
-  are flagged via `srs_fallback` column.
-- **Note:** FPC values are passed through unchanged from the full design to cell
-  subsets — they represent population N_h per stratum, not per cell.
+- **Note:** SRS fallback when design-based variance is unidentifiable (e.g., all
+  strata contribute zero variance) or when the cell has fewer than `min_n` valid
+  observations. Formula: `V_SRS = Σ w_i(y_i - ȳ)² / (Σ w_j)² × n/(n-1)`.
+  Cells using SRS fallback are flagged via `srs_fallback` column.
 - **Edge case**: Zero-variance cells (all observations identical) set precision to
   NaN to avoid infinite weights in second-stage WLS.
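
The SRS fallback formula and the precision-weight edge case in the REGISTRY.md entry above can be checked numerically with a short sketch. The function names `srs_variance` and `precision_weight` are hypothetical and exist only for this illustration; the sketch assumes NaN values have already been dropped and that the cell has at least two observations.

```python
import numpy as np


def srs_variance(y, w):
    """V_SRS = sum(w_i (y_i - ybar)^2) / (sum w_j)^2 * n / (n - 1)."""
    n = len(y)
    sum_w = w.sum()
    y_bar = np.sum(w * y) / sum_w
    return float(np.sum(w * (y - y_bar) ** 2) / sum_w**2 * n / (n - 1))


def precision_weight(variance):
    """Inverse-variance weight; NaN for zero or undefined variance."""
    if np.isnan(variance) or variance <= 0:
        return float("nan")
    return 1.0 / variance


y = np.array([1.0, 2.0, 3.0, 4.0])
w = np.ones(4)

v = srs_variance(y, w)
# With equal weights the formula reduces to the textbook s^2 / n.
textbook = np.var(y, ddof=1) / len(y)

# A cell whose observations are all identical has zero variance, so its
# precision is set to NaN rather than infinity.
v_flat = srs_variance(np.full(4, 3.0), w)

print(v, textbook, precision_weight(v), precision_weight(v_flat))
```

The equal-weight case is a useful sanity check: if the fallback did not reduce to `s² / n` there, the `n/(n-1)` correction or the `(Σ w_j)²` denominator would be misplaced.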
