Skip to content

Commit b237022

Browse files
igerberclaude
andcommitted
Address PR #376 R9 (2 P3)
R9 P3 #1 (helper error message canonical-kwarg consistency): `_resolve_pretest_unit_weights`'s TypeError on non-`SurveyDesign`-like input still said `survey=` must be a SurveyDesign — but on the data-in wrappers (workflow / joint_pretrends_test / joint_homogeneity_test) the canonical kwarg is now `survey_design=`. Updated the message to name `survey_design=` (with `survey=` flagged as the deprecated alias) and to point pre-resolved-design users to the array-in pretest helpers, mirroring HAD.fit's data-in guard. R9 P3 #2 (legacy-vs-canonical parity coverage on data-in pretests): Added 3 parity tests (test_legacy_alias_parity_survey on joint_pretrends_test + joint_homogeneity_test, plus test_legacy_alias_parity_survey_overall on did_had_pretest_workflow overall path). Locks the rebinding contract on the data-in surfaces that previously only had smoke / warning / mutex coverage. 558 tests pass (was 555 + 3 new R9 P3 parity tests). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 9c0d742 commit b237022

2 files changed

Lines changed: 118 additions & 2 deletions

File tree

diff_diff/had_pretests.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3262,9 +3262,22 @@ def _resolve_pretest_unit_weights(
32623262
return weights_unit, None
32633263
# survey is not None
32643264
if not hasattr(survey, "resolve"):
3265+
# PR #376 R9 P3: error message names the canonical kwarg
3266+
# `survey_design=` (with the deprecated `survey=` alias mentioned
3267+
# for back-compat), and points pre-resolved-design users to the
3268+
# array-in pretest helpers where ResolvedSurveyDesign /
3269+
# make_pweight_design(arr) belong.
32653270
raise TypeError(
3266-
f"{caller_name}: survey= must be a SurveyDesign instance "
3267-
f"(with .resolve()); got {type(survey).__name__}."
3271+
f"{caller_name}: `survey_design=` (or the deprecated `survey=` "
3272+
f"alias) accepts a SurveyDesign instance (column-referencing, "
3273+
f"gets `.resolve(data)`'d at fit time) on data-in surfaces; "
3274+
f"got {type(survey).__name__} (no `.resolve()` method). "
3275+
"If you have a pre-resolved ResolvedSurveyDesign or used "
3276+
"`make_pweight_design(arr)`, that pattern is for the array-in "
3277+
"pretest helpers (`stute_test`, `yatchew_hr_test`, "
3278+
"`stute_joint_pretest`). On data-in surfaces, add the weights "
3279+
"as a column on `data` and pass "
3280+
"`survey_design=SurveyDesign(weights='col_name', ...)`."
32683281
)
32693282
resolved_full = survey.resolve(data)
32703283
if getattr(resolved_full, "replicate_weights", None) is not None:

tests/test_had_dual_knob_deprecation.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,41 @@ def test_three_way_mutex_design_plus_survey(self, event_study_panel):
605605
seed=0,
606606
)
607607

608+
def test_legacy_alias_parity_survey(self, event_study_panel):
609+
"""PR #376 R9 P3: deprecated `survey=SurveyDesign(...)` ≡ canonical
610+
`survey_design=SurveyDesign(...)` on joint_pretrends_test (locks
611+
rebinding parity)."""
612+
df = event_study_panel
613+
sd = SurveyDesign(weights="w")
614+
with warnings.catch_warnings():
615+
warnings.simplefilter("ignore", DeprecationWarning)
616+
r_legacy = joint_pretrends_test(
617+
df,
618+
"y",
619+
"d",
620+
"time",
621+
"unit",
622+
pre_periods=[0],
623+
base_period=1,
624+
survey=sd,
625+
n_bootstrap=199,
626+
seed=0,
627+
)
628+
r_new = joint_pretrends_test(
629+
df,
630+
"y",
631+
"d",
632+
"time",
633+
"unit",
634+
pre_periods=[0],
635+
base_period=1,
636+
survey_design=sd,
637+
n_bootstrap=199,
638+
seed=0,
639+
)
640+
assert r_legacy.cvm_stat_joint == r_new.cvm_stat_joint
641+
assert r_legacy.p_value == r_new.p_value
642+
608643

609644
class TestJointHomogeneityTestDeprecation:
610645
def test_survey_design_kwarg_smoke(self, event_study_panel):
@@ -656,6 +691,40 @@ def test_survey_emits_deprecation_warning(self, event_study_panel):
656691
seed=0,
657692
)
658693

694+
def test_legacy_alias_parity_survey(self, event_study_panel):
695+
"""PR #376 R9 P3: deprecated `survey=SurveyDesign(...)` ≡ canonical
696+
`survey_design=SurveyDesign(...)` on joint_homogeneity_test."""
697+
df = event_study_panel
698+
sd = SurveyDesign(weights="w")
699+
with warnings.catch_warnings():
700+
warnings.simplefilter("ignore", DeprecationWarning)
701+
r_legacy = joint_homogeneity_test(
702+
df,
703+
"y",
704+
"d",
705+
"time",
706+
"unit",
707+
post_periods=[2, 3],
708+
base_period=1,
709+
survey=sd,
710+
n_bootstrap=199,
711+
seed=0,
712+
)
713+
r_new = joint_homogeneity_test(
714+
df,
715+
"y",
716+
"d",
717+
"time",
718+
"unit",
719+
post_periods=[2, 3],
720+
base_period=1,
721+
survey_design=sd,
722+
n_bootstrap=199,
723+
seed=0,
724+
)
725+
assert r_legacy.cvm_stat_joint == r_new.cvm_stat_joint
726+
assert r_legacy.p_value == r_new.p_value
727+
659728

660729
class TestHADFitDeprecation:
661730
def test_survey_design_kwarg_smoke(self, two_period_panel):
@@ -864,6 +933,40 @@ def test_three_way_mutex_all_three(self, two_period_panel):
864933
seed=0,
865934
)
866935

936+
def test_legacy_alias_parity_survey_overall(self, two_period_panel):
937+
"""PR #376 R9 P3: deprecated `survey=SurveyDesign(...)` ≡ canonical
938+
`survey_design=SurveyDesign(...)` on
939+
did_had_pretest_workflow(aggregate='overall'). Locks rebinding
940+
parity on the workflow's overall-path data-in surface."""
941+
df = two_period_panel
942+
sd = SurveyDesign(weights="w")
943+
with warnings.catch_warnings():
944+
warnings.simplefilter("ignore", UserWarning) # QUG-skip warning
945+
warnings.simplefilter("ignore", DeprecationWarning)
946+
r_legacy = did_had_pretest_workflow(
947+
df,
948+
"y",
949+
"d",
950+
"time",
951+
"unit",
952+
survey=sd,
953+
n_bootstrap=199,
954+
seed=0,
955+
)
956+
r_new = did_had_pretest_workflow(
957+
df,
958+
"y",
959+
"d",
960+
"time",
961+
"unit",
962+
survey_design=sd,
963+
n_bootstrap=199,
964+
seed=0,
965+
)
966+
assert r_legacy.stute.cvm_stat == r_new.stute.cvm_stat
967+
assert r_legacy.stute.p_value == r_new.stute.p_value
968+
assert r_legacy.yatchew.t_stat_hr == r_new.yatchew.t_stat_hr
969+
867970

868971
# =============================================================================
869972
# 3. PR #376 R2 P1: extended dispatch-matrix coverage on the new front door

0 commit comments

Comments
 (0)