@@ -1313,64 +1313,85 @@ def trim_weights(
13131313
13141314
13151315def _cell_mean_variance (
1316- y : np .ndarray ,
1317- weights : np .ndarray ,
1318- cell_resolved : ResolvedSurveyDesign ,
1316+ y_full : np .ndarray ,
1317+ full_resolved : ResolvedSurveyDesign ,
1318+ cell_mask : np .ndarray ,
1319+ min_n : int ,
13191320) -> Tuple [float , float , int , bool ]:
13201321 """Compute design-based mean and variance of the weighted mean for one cell.
13211322
1323+ Uses full-design domain estimation: the influence function is zero-padded
1324+ outside the cell, preserving the full strata/PSU structure for variance
1325+ estimation. This is the methodologically correct approach for domain
1326+ estimation under complex survey designs (Lumley 2004, Section 3.4).
1327+
13221328 Parameters
13231329 ----------
1324- y : np.ndarray
1325- Outcome values for the cell (may contain NaN).
1326- weights : np.ndarray
1327- Resolved weights for the cell (already extracted from ResolvedSurveyDesign).
1328- cell_resolved : ResolvedSurveyDesign
1329- Resolved survey design subsetted to this cell.
1330+ y_full : np.ndarray
1331+ Outcome values for the full dataset (may contain NaN).
1332+ full_resolved : ResolvedSurveyDesign
1333+ Full-sample resolved survey design.
1334+ cell_mask : np.ndarray
1335+ Boolean mask identifying cell members in the full dataset.
1336+ min_n : int
1337+ Minimum valid observations for design-based variance. Below this
1338+ threshold, SRS fallback is used.
13301339
13311340 Returns
13321341 -------
13331342 mean : float
13341343 Design-weighted cell mean.
13351344 variance : float
13361345 Design-based variance of the cell mean (>= 0). Uses SRS fallback
1337- when the design-based estimate is unidentifiable.
1346+ when the design-based estimate is unidentifiable or n_valid < min_n.
13381347 n_valid : int
1339- Number of non-missing observations.
1348+ Number of non-missing observations in the cell.
13401349 used_srs_fallback : bool
13411350 True if SRS variance was used instead of design-based.
13421351 """
1343- valid = ~ np .isnan (y )
1352+ y_cell = y_full [cell_mask ]
1353+ w_cell = full_resolved .weights [cell_mask ]
1354+ valid = ~ np .isnan (y_cell )
13441355 n_valid = int (np .sum (valid ))
13451356
13461357 if n_valid == 0 :
13471358 return np .nan , np .nan , 0 , False
13481359
1349- if n_valid == 1 :
1350- y_bar = float (y [valid ][0 ])
1360+ if n_valid < 2 :
1361+ y_bar = float (y_cell [valid ][0 ])
13511362 return y_bar , np .nan , 1 , False
13521363
1353- # Zero out weights for NaN observations (subpopulation approach)
1354- w = weights .copy ()
1355- y_clean = np .where (valid , y , 0.0 )
1356- w_valid = w * valid .astype (np .float64 )
1357- sum_w = np .sum (w_valid )
1364+ # Weighted mean from cell members (NaN-safe)
1365+ w_valid = w_cell * valid .astype (np .float64 )
1366+ y_clean = np .where (valid , y_cell , 0.0 )
1367+ sum_w = float (np .sum (w_valid ))
13581368
13591369 if sum_w <= 0 :
13601370 return np .nan , np .nan , n_valid , False
13611371
1362- # Design-weighted mean
13631372 y_bar = float (np .sum (w_valid * y_clean ) / sum_w )
13641373
1365- # Influence function: psi_i = w_i * (y_i - y_bar) / sum(w)
1366- psi = w_valid * (y_clean - y_bar ) / sum_w
1367-
1368- # Route to TSL or replicate variance
1374+ # SRS fallback if below min_n threshold
13691375 used_srs = False
1370- if cell_resolved .uses_replicate_variance :
1371- variance , _ = compute_replicate_if_variance (psi , cell_resolved )
1376+ if n_valid < min_n :
1377+ resid_sq = w_valid * (y_clean - y_bar ) ** 2
1378+ variance = float (np .sum (resid_sq ) / (sum_w ** 2 ) * n_valid / (n_valid - 1 ))
1379+ return y_bar , max (variance , 0.0 ), n_valid , True
1380+
1381+ # Full-design domain estimation: construct full-length psi with zeros
1382+ # outside the cell, preserving full strata/PSU structure for variance
1383+ n_total = len (y_full )
1384+ psi = np .zeros (n_total )
1385+ # Positions in full array where cell member has valid data
1386+ cell_indices = np .where (cell_mask )[0 ]
1387+ valid_positions = cell_indices [valid ]
1388+ psi [valid_positions ] = w_valid [valid ] * (y_clean [valid ] - y_bar ) / sum_w
1389+
1390+ # Route to TSL or replicate variance using the full design
1391+ if full_resolved .uses_replicate_variance :
1392+ variance , _ = compute_replicate_if_variance (psi , full_resolved )
13721393 else :
1373- variance = compute_survey_if_variance (psi , cell_resolved )
1394+ variance = compute_survey_if_variance (psi , full_resolved )
13741395
13751396 # SRS fallback when design-based variance is unidentifiable
13761397 if np .isnan (variance ):
@@ -1397,9 +1418,10 @@ def aggregate_survey(
13971418 columns. Returns a panel-ready DataFrame with precision weights and a
13981419 pre-configured :class:`SurveyDesign` for second-stage DiD estimation.
13991420
1400- This follows R's ``survey::svyby()`` pattern: the survey design is
1401- subsetted to each cell and domain-level statistics are computed using
1402- the within-cell strata/PSU structure.
1421+ Each cell is treated as a subpopulation/domain of the full survey
1422+ design: influence function values are zero-padded outside the cell,
1423+ preserving full strata/PSU structure for variance estimation per
1424+ Lumley (2004) Section 3.4.
14031425
14041426 Parameters
14051427 ----------
@@ -1446,7 +1468,7 @@ def aggregate_survey(
14461468 ... )
14471469 >>> result = DifferenceInDifferences().fit(
14481470 ... panel, outcome="smoking_rate_mean",
1449- ... treatment="treated", time="post ", survey_design=stage2,
1471+ ... treatment="treated", time="year", survey_design=stage2,
14501472 ... )
14511473 """
14521474 import warnings
@@ -1482,12 +1504,21 @@ def aggregate_survey(
14821504 f"lonely_psu must be 'remove', 'certainty', or 'adjust', got '{ lonely_psu } '"
14831505 )
14841506
1507+ # --- Empty-input guard ---
1508+ if data .empty :
1509+ raise ValueError ("data must be non-empty" )
1510+
14851511 # --- Resolve design once on full data ---
14861512 effective_design = (
14871513 replace (survey_design , lonely_psu = lonely_psu ) if lonely_psu else survey_design
14881514 )
14891515 full_resolved = effective_design .resolve (data )
14901516
1517+ # --- Precompute full-length outcome/covariate arrays ---
1518+ n_total = len (data )
1519+ all_vars = outcome_cols + cov_cols
1520+ y_arrays : Dict [str , np .ndarray ] = {var : data [var ].values .astype (np .float64 ) for var in all_vars }
1521+
14911522 # --- Per-cell computation ---
14921523 grouped = data .groupby (by_cols , sort = True )
14931524 rows : List [Dict [str , Any ]] = []
@@ -1496,32 +1527,17 @@ def aggregate_survey(
14961527
14971528 for cell_key , cell_df in grouped :
14981529 cell_idx = np .array (cell_df .index )
1499- # Convert to positional indices for array subsetting
15001530 pos_idx = data .index .get_indexer (cell_idx )
15011531
1502- cell_n = len (pos_idx )
1503- cell_key_str = str (cell_key )
1532+ # Boolean mask for full-design domain estimation
1533+ cell_mask = np .zeros (n_total , dtype = bool )
1534+ cell_mask [pos_idx ] = True
15041535
1505- # Subset arrays from full resolved design
1506- cell_w = full_resolved .weights [pos_idx ]
1507- cell_strata = full_resolved .strata [pos_idx ] if full_resolved .strata is not None else None
1508- cell_psu = full_resolved .psu [pos_idx ] if full_resolved .psu is not None else None
1509- cell_fpc = full_resolved .fpc [pos_idx ] if full_resolved .fpc is not None else None
1510-
1511- cell_n_strata = int (len (np .unique (cell_strata ))) if cell_strata is not None else 0
1512- cell_n_psu = int (len (np .unique (cell_psu ))) if cell_psu is not None else 0
1513-
1514- cell_resolved = full_resolved .subset_to_units (
1515- row_idx = pos_idx ,
1516- weights = cell_w ,
1517- strata = cell_strata ,
1518- psu = cell_psu ,
1519- fpc = cell_fpc ,
1520- n_strata = cell_n_strata ,
1521- n_psu = cell_n_psu ,
1522- )
1536+ cell_n = int (np .sum (cell_mask ))
1537+ cell_key_str = str (cell_key )
15231538
1524- # Cell-level statistics
1539+ # Cell-level statistics (Kish ESS is a property of the cell)
1540+ cell_w = full_resolved .weights [cell_mask ]
15251541 sum_w = float (np .sum (cell_w ))
15261542 sum_w2 = float (np .sum (cell_w ** 2 ))
15271543 cell_n_eff = (sum_w ** 2 / sum_w2 ) if sum_w2 > 0 else 0.0
@@ -1539,10 +1555,14 @@ def aggregate_survey(
15391555
15401556 cell_srs_fallback = False
15411557
1542- # Outcomes: mean + SE + n + precision
1558+ # Outcomes: mean + SE + n + precision (full-design domain estimation)
15431559 for var in outcome_cols :
1544- y = cell_df [var ].values .astype (np .float64 )
1545- y_bar , variance , n_valid , used_srs = _cell_mean_variance (y , cell_w , cell_resolved )
1560+ y_bar , variance , n_valid , used_srs = _cell_mean_variance (
1561+ y_arrays [var ],
1562+ full_resolved ,
1563+ cell_mask ,
1564+ min_n ,
1565+ )
15461566 se = float (np .sqrt (variance )) if not np .isnan (variance ) else np .nan
15471567
15481568 if used_srs :
@@ -1562,14 +1582,14 @@ def aggregate_survey(
15621582 row [f"{ var } _n" ] = n_valid
15631583 row [f"{ var } _precision" ] = precision
15641584
1565- # Covariates: mean only
1585+ # Covariates: design-weighted mean only
15661586 for var in cov_cols :
1567- y = cell_df [var ]. values . astype ( np . float64 )
1568- valid = ~ np .isnan (y )
1587+ y_cell = y_arrays [var ][ cell_mask ]
1588+ valid = ~ np .isnan (y_cell )
15691589 w_valid = cell_w * valid .astype (np .float64 )
15701590 sw = float (np .sum (w_valid ))
15711591 if sw > 0 :
1572- row [f"{ var } _mean" ] = float (np .sum (w_valid * np .where (valid , y , 0.0 )) / sw )
1592+ row [f"{ var } _mean" ] = float (np .sum (w_valid * np .where (valid , y_cell , 0.0 )) / sw )
15731593 else :
15741594 row [f"{ var } _mean" ] = np .nan
15751595
0 commit comments