Skip to content

Commit 85c3710

Browse files
igerber and claude committed
Remove CS cluster-as-PSU injection, add CS aggregate survey tests
CallawaySantAnna uses weights-only survey (strata/PSU/FPC rejected), so cluster-as-PSU injection is unnecessary and misleading. The user's cluster= parameter is handled by the existing non-survey clustering path. Removed _inject_cluster_as_psu call and unused imports. Added aggregate="group" and aggregate="all" survey tests for CS.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 133df19 commit 85c3710

2 files changed

Lines changed: 37 additions & 19 deletions

File tree

diff_diff/staggered.py

Lines changed: 8 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1180,8 +1180,6 @@ def fit(
11801180

11811181
# Resolve survey design if provided
11821182
from diff_diff.survey import (
1183-
_inject_cluster_as_psu,
1184-
_resolve_effective_cluster,
11851183
_resolve_survey_for_fit,
11861184
_validate_unit_constant_survey,
11871185
)
@@ -1286,25 +1284,16 @@ def fit(
12861284
"cohorts when there are no never-treated units."
12871285
)
12881286

1289-
# Resolve effective cluster and inject cluster-as-PSU for survey variance
1290-
if resolved_survey is not None:
1291-
cluster_var = self.cluster if self.cluster is not None else unit
1292-
cluster_ids_raw = df[cluster_var].values if cluster_var in df.columns else None
1293-
effective_cluster_ids = _resolve_effective_cluster(
1294-
resolved_survey,
1295-
cluster_ids_raw,
1296-
cluster_var if self.cluster is not None else None,
1297-
)
1298-
resolved_survey = _inject_cluster_as_psu(resolved_survey, effective_cluster_ids)
1299-
# Recompute metadata after PSU injection
1300-
if resolved_survey.psu is not None and survey_metadata is not None:
1287+
# Note: CallawaySantAnna uses weights-only survey (strata/PSU/FPC
1288+
# rejected above). We do NOT inject cluster-as-PSU here because CS
1289+
# per-cell SEs use IF-based variance, not TSL. The user's cluster=
1290+
# parameter is handled by the existing non-survey clustering path.
1291+
if resolved_survey is not None and survey_metadata is not None:
1292+
# Just recompute metadata with the resolved design (no PSU injection)
1293+
if survey_design.weights:
13011294
from diff_diff.survey import compute_survey_metadata
13021295

1303-
raw_w = (
1304-
data[survey_design.weights].values.astype(np.float64)
1305-
if survey_design.weights
1306-
else np.ones(len(data), dtype=np.float64)
1307-
)
1296+
raw_w = data[survey_design.weights].values.astype(np.float64)
13081297
survey_metadata = compute_survey_metadata(resolved_survey, raw_w)
13091298

13101299
# Pre-compute data structures for efficient ATT(g,t) computation

tests/test_survey_phase4.py

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -829,6 +829,35 @@ def test_strata_psu_fpc_raises(self, staggered_survey_data):
829829
survey_design=sd_full,
830830
)
831831

832+
def test_aggregate_group_with_survey(self, staggered_survey_data, survey_design_weights_only):
833+
"""aggregate='group' works with weights-only survey design."""
834+
result = CallawaySantAnna(estimation_method="reg").fit(
835+
staggered_survey_data,
836+
"outcome",
837+
"unit",
838+
"period",
839+
"first_treat",
840+
aggregate="group",
841+
survey_design=survey_design_weights_only,
842+
)
843+
assert result.group_effects is not None
844+
assert len(result.group_effects) > 0
845+
846+
def test_aggregate_all_with_survey(self, staggered_survey_data, survey_design_weights_only):
847+
"""aggregate='all' works with weights-only survey design."""
848+
result = CallawaySantAnna(estimation_method="reg").fit(
849+
staggered_survey_data,
850+
"outcome",
851+
"unit",
852+
"period",
853+
"first_treat",
854+
aggregate="all",
855+
survey_design=survey_design_weights_only,
856+
)
857+
assert np.isfinite(result.overall_att)
858+
assert result.event_study_effects is not None
859+
assert result.group_effects is not None
860+
832861
def test_bootstrap_survey_raises(self, staggered_survey_data, survey_design_weights_only):
833862
"""Bootstrap + survey should raise NotImplementedError."""
834863
with pytest.raises(NotImplementedError, match="[Bb]ootstrap"):

0 commit comments

Comments
 (0)