Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 2 additions & 76 deletions diff_diff/staggered.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,83 +354,9 @@ def _collapse_survey_to_unit_level(resolved_survey, df, unit_col, all_units):
Survey design columns are constant within units (validated upstream).
This extracts one row per unit, aligned to ``all_units`` ordering.
"""
from diff_diff.survey import ResolvedSurveyDesign

n_units = len(all_units)
# Use groupby().first() to get one value per unit, then reindex
unit_groups = df.groupby(unit_col)

weights_unit = (
pd.Series(resolved_survey.weights, index=df.index)
.groupby(df[unit_col])
.first()
.reindex(all_units)
.values
)

strata_unit = None
if resolved_survey.strata is not None:
strata_unit = (
pd.Series(resolved_survey.strata, index=df.index)
.groupby(df[unit_col])
.first()
.reindex(all_units)
.values
)

psu_unit = None
if resolved_survey.psu is not None:
psu_unit = (
pd.Series(resolved_survey.psu, index=df.index)
.groupby(df[unit_col])
.first()
.reindex(all_units)
.values
)
from diff_diff.survey import collapse_survey_to_unit_level

fpc_unit = None
if resolved_survey.fpc is not None:
fpc_unit = (
pd.Series(resolved_survey.fpc, index=df.index)
.groupby(df[unit_col])
.first()
.reindex(all_units)
.values
)

# Collapse replicate weights to unit level (same groupby pattern)
rep_weights_unit = None
if resolved_survey.replicate_weights is not None:
R = resolved_survey.replicate_weights.shape[1]
rep_weights_unit = np.zeros((n_units, R))
for r in range(R):
rep_weights_unit[:, r] = (
pd.Series(resolved_survey.replicate_weights[:, r], index=df.index)
.groupby(df[unit_col])
.first()
.reindex(all_units)
.values
)

return ResolvedSurveyDesign(
weights=weights_unit.astype(np.float64),
weight_type=resolved_survey.weight_type,
strata=strata_unit,
psu=psu_unit,
fpc=fpc_unit,
n_strata=resolved_survey.n_strata,
n_psu=resolved_survey.n_psu,
lonely_psu=resolved_survey.lonely_psu,
replicate_weights=rep_weights_unit,
replicate_method=resolved_survey.replicate_method,
fay_rho=resolved_survey.fay_rho,
n_replicates=resolved_survey.n_replicates,
replicate_strata=resolved_survey.replicate_strata,
combined_weights=resolved_survey.combined_weights,
replicate_scale=resolved_survey.replicate_scale,
replicate_rscales=resolved_survey.replicate_rscales,
mse=resolved_survey.mse,
)
return collapse_survey_to_unit_level(resolved_survey, df, unit_col, all_units)

def _precompute_structures(
self,
Expand Down
Loading
Loading