Skip to content

Commit 5d62910

Browse files
igerber and claude committed
Add survey design support to StaggeredTripleDifference estimator
Thread survey weights through all three pairwise DiD comparisons (propensity scores, outcome regression, Riesz representers) with design-based variance at aggregation via CallawaySantAnna mixin infrastructure. Extract collapse_survey_to_unit_level to survey.py for reuse. Full test coverage across estimation methods, survey designs, and aggregation modes.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 274ade9 commit 5d62910

6 files changed

Lines changed: 1038 additions & 267 deletions

File tree

diff_diff/staggered.py

Lines changed: 2 additions & 76 deletions
Original file line number | Diff line number | Diff line change
@@ -354,83 +354,9 @@ def _collapse_survey_to_unit_level(resolved_survey, df, unit_col, all_units):
354354
Survey design columns are constant within units (validated upstream).
355355
This extracts one row per unit, aligned to ``all_units`` ordering.
356356
"""
357-
from diff_diff.survey import ResolvedSurveyDesign
358-
359-
n_units = len(all_units)
360-
# Use groupby().first() to get one value per unit, then reindex
361-
unit_groups = df.groupby(unit_col)
362-
363-
weights_unit = (
364-
pd.Series(resolved_survey.weights, index=df.index)
365-
.groupby(df[unit_col])
366-
.first()
367-
.reindex(all_units)
368-
.values
369-
)
370-
371-
strata_unit = None
372-
if resolved_survey.strata is not None:
373-
strata_unit = (
374-
pd.Series(resolved_survey.strata, index=df.index)
375-
.groupby(df[unit_col])
376-
.first()
377-
.reindex(all_units)
378-
.values
379-
)
380-
381-
psu_unit = None
382-
if resolved_survey.psu is not None:
383-
psu_unit = (
384-
pd.Series(resolved_survey.psu, index=df.index)
385-
.groupby(df[unit_col])
386-
.first()
387-
.reindex(all_units)
388-
.values
389-
)
357+
from diff_diff.survey import collapse_survey_to_unit_level
390358

391-
fpc_unit = None
392-
if resolved_survey.fpc is not None:
393-
fpc_unit = (
394-
pd.Series(resolved_survey.fpc, index=df.index)
395-
.groupby(df[unit_col])
396-
.first()
397-
.reindex(all_units)
398-
.values
399-
)
400-
401-
# Collapse replicate weights to unit level (same groupby pattern)
402-
rep_weights_unit = None
403-
if resolved_survey.replicate_weights is not None:
404-
R = resolved_survey.replicate_weights.shape[1]
405-
rep_weights_unit = np.zeros((n_units, R))
406-
for r in range(R):
407-
rep_weights_unit[:, r] = (
408-
pd.Series(resolved_survey.replicate_weights[:, r], index=df.index)
409-
.groupby(df[unit_col])
410-
.first()
411-
.reindex(all_units)
412-
.values
413-
)
414-
415-
return ResolvedSurveyDesign(
416-
weights=weights_unit.astype(np.float64),
417-
weight_type=resolved_survey.weight_type,
418-
strata=strata_unit,
419-
psu=psu_unit,
420-
fpc=fpc_unit,
421-
n_strata=resolved_survey.n_strata,
422-
n_psu=resolved_survey.n_psu,
423-
lonely_psu=resolved_survey.lonely_psu,
424-
replicate_weights=rep_weights_unit,
425-
replicate_method=resolved_survey.replicate_method,
426-
fay_rho=resolved_survey.fay_rho,
427-
n_replicates=resolved_survey.n_replicates,
428-
replicate_strata=resolved_survey.replicate_strata,
429-
combined_weights=resolved_survey.combined_weights,
430-
replicate_scale=resolved_survey.replicate_scale,
431-
replicate_rscales=resolved_survey.replicate_rscales,
432-
mse=resolved_survey.mse,
433-
)
359+
return collapse_survey_to_unit_level(resolved_survey, df, unit_col, all_units)
434360

435361
def _precompute_structures(
436362
self,

0 commit comments

Comments
 (0)