Skip to content

Commit db170bd

Browse files
igerber and claude
committed
Address PR #353 CI review round 4 (1 P1 + 2 P3)
P1 - stringified-label collision guard in stute_joint_pretest: The core indexed residuals_arrays / fitted_arrays by `str(k)` with no uniqueness check on the stringified keys. Two distinct raw keys whose str() forms collide (e.g. {1: ..., "1": ...} both stringify to "1", or custom objects with identical __str__) would silently overwrite one entry and then be double-counted in S_joint = sum(S_k) because the surviving horizon's statistic gets summed twice while n_horizons still reports K=2. That produces wrong methodology output with no diagnostic. Fix: compute the stringified labels once up front and reject any collision explicitly with a ValueError listing which raw keys collide to which stringified form. Centralizes the check before any residual/fitted array is dropped. Replaces the ad-hoc post-hoc re-keying with a reuse of the pre-computed collision-free list. P3 - dedupe staggered-filter UserWarning: `_validate_had_panel_event_study` already warns on the staggered auto-filter path; both joint-pretest wrappers and the event-study workflow were re-emitting the same information with a wrapper-prefixed message. Each staggered call therefore surfaced two warnings to the user. Removes the secondary emissions; wrappers now consume `_filter_info` silently. Existing tests still pass because the validator's own `"Staggered-timing panel detected"` message satisfies the regex matchers. P3 - collision regression test: new `TestStuteJointPretest::test_stringified_key_collision_raises` exercises (a) the int 1 + str "1" case and (b) a pair of custom objects with identical __str__ but distinct hash; both must raise `ValueError` with "collision after str" in the message. 125 tests pass (124 + 1 new R4 collision regression); black/ruff/mypy clean. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 0040bad commit db170bd

2 files changed

Lines changed: 103 additions & 40 deletions

File tree

diff_diff/had_pretests.py

Lines changed: 45 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1976,12 +1976,36 @@ def stute_joint_pretest(
19761976
if not np.all(np.isfinite(X)):
19771977
raise ValueError("design_matrix contains non-finite values (NaN/inf).")
19781978

1979-
horizon_labels = list(residuals_by_horizon.keys())
1980-
K = len(horizon_labels)
1979+
raw_horizon_labels = list(residuals_by_horizon.keys())
1980+
K = len(raw_horizon_labels)
1981+
1982+
# Stringified-label collision guard: distinct raw keys whose str()
1983+
# representations collide (e.g. {1: ..., "1": ...}) would
1984+
# overwrite each other in residuals_arrays / fitted_arrays, letting
1985+
# the surviving horizon be double-counted in S_joint = sum of S_k
1986+
# and leaving `n_horizons` inconsistent with the number of distinct
1987+
# diagnostic statistics. Reject explicitly rather than silently
1988+
# collapsing the test.
1989+
str_labels = [str(k) for k in raw_horizon_labels]
1990+
if len(set(str_labels)) != len(str_labels):
1991+
from collections import Counter
1992+
1993+
dup_strs = [s for s, c in Counter(str_labels).items() if c > 1]
1994+
collisions = {s: [k for k in raw_horizon_labels if str(k) == s] for s in dup_strs}
1995+
raise ValueError(
1996+
f"Horizon label collision after str() stringification: "
1997+
f"{collisions!r}. The joint Stute helpers index residuals "
1998+
f"and fitted values by str(label); distinct raw keys whose "
1999+
f"stringified form collides would silently overwrite each "
2000+
f"other and double-count the surviving horizon in S_joint. "
2001+
f"Use string-distinct horizon labels (e.g. 1997 and 1998 "
2002+
f'as int, or "1997" and "1998" as str; not both).'
2003+
)
2004+
19812005
any_nan = False
19822006
residuals_arrays: Dict[str, np.ndarray] = {}
19832007
fitted_arrays: Dict[str, np.ndarray] = {}
1984-
for k in horizon_labels:
2008+
for k in raw_horizon_labels:
19852009
eps_k = np.asarray(residuals_by_horizon[k], dtype=np.float64)
19862010
fit_k = np.asarray(fitted_by_horizon[k], dtype=np.float64)
19872011
if eps_k.shape != (G,) or fit_k.shape != (G,):
@@ -1997,8 +2021,9 @@ def stute_joint_pretest(
19972021

19982022
# Re-key to str labels consistently (wrappers already pass str; direct
19992023
# callers may pass int/object). String identity per the documented
2000-
# horizon_labels contract.
2001-
horizon_labels = [str(k) for k in horizon_labels]
2024+
# horizon_labels contract. The collision guard above ensures this
2025+
# stringification is injective on the provided keys.
2026+
horizon_labels = str_labels
20022027

20032028
if any_nan:
20042029
return StuteJointResult(
@@ -2242,7 +2267,7 @@ def joint_pretrends_test(
22422267
n_periods = int(data[time_col].nunique())
22432268
data_filtered: pd.DataFrame = data
22442269
if n_periods >= 3:
2245-
F_val, t_pre_list, _t_post_list, data_filtered, filter_info = (
2270+
F_val, t_pre_list, _t_post_list, data_filtered, _filter_info = (
22462271
_validate_had_panel_event_study(
22472272
data,
22482273
outcome_col=outcome_col,
@@ -2252,16 +2277,10 @@ def joint_pretrends_test(
22522277
first_treat_col=first_treat_col,
22532278
)
22542279
)
2255-
if filter_info is not None:
2256-
warnings.warn(
2257-
f"joint_pretrends_test: staggered panel auto-filtered to "
2258-
f"last cohort (F_last={filter_info['F_last']!r}, "
2259-
f"n_kept={filter_info['n_kept']}, "
2260-
f"n_dropped={filter_info['n_dropped']}). "
2261-
f"Paper Appendix B.2 prescription.",
2262-
UserWarning,
2263-
stacklevel=2,
2264-
)
2280+
# `_validate_had_panel_event_study` already emits its own
2281+
# `UserWarning` on the staggered-filter path; the wrapper
2282+
# consumes `_filter_info` silently to avoid duplicated console
2283+
# noise (R4 code-quality fix).
22652284
# Subset invariants: the caller's base_period and pre_periods
22662285
# must be pre-treatment periods under the validator's partition.
22672286
if base_period not in t_pre_list:
@@ -2429,7 +2448,7 @@ def joint_homogeneity_test(
24292448
n_periods = int(data[time_col].nunique())
24302449
data_filtered: pd.DataFrame = data
24312450
if n_periods >= 3:
2432-
F_val, t_pre_list, t_post_list, data_filtered, filter_info = (
2451+
F_val, t_pre_list, t_post_list, data_filtered, _filter_info = (
24332452
_validate_had_panel_event_study(
24342453
data,
24352454
outcome_col=outcome_col,
@@ -2439,16 +2458,10 @@ def joint_homogeneity_test(
24392458
first_treat_col=first_treat_col,
24402459
)
24412460
)
2442-
if filter_info is not None:
2443-
warnings.warn(
2444-
f"joint_homogeneity_test: staggered panel auto-filtered "
2445-
f"to last cohort (F_last={filter_info['F_last']!r}, "
2446-
f"n_kept={filter_info['n_kept']}, "
2447-
f"n_dropped={filter_info['n_dropped']}). "
2448-
f"Paper Appendix B.2 prescription.",
2449-
UserWarning,
2450-
stacklevel=2,
2451-
)
2461+
# `_validate_had_panel_event_study` already emits its own
2462+
# `UserWarning` on the staggered-filter path; the wrapper
2463+
# consumes `_filter_info` silently to avoid duplicated console
2464+
# noise (R4 code-quality fix).
24522465
if base_period not in t_pre_list:
24532466
raise ValueError(
24542467
f"base_period={base_period!r} is not in the validated "
@@ -2615,26 +2628,18 @@ def did_had_pretest_workflow(
26152628
)
26162629

26172630
if aggregate == "event_study":
2618-
F, t_pre_list, t_post_list, data_filtered, filter_info = _validate_multi_period_panel(
2631+
F, t_pre_list, t_post_list, data_filtered, _filter_info = _validate_multi_period_panel(
26192632
data,
26202633
outcome_col=outcome_col,
26212634
dose_col=dose_col,
26222635
time_col=time_col,
26232636
unit_col=unit_col,
26242637
first_treat_col=first_treat_col,
26252638
)
2626-
if filter_info is not None:
2627-
warnings.warn(
2628-
f"HAD event-study pre-test: staggered panel auto-"
2629-
f"filtered to last cohort "
2630-
f"(F_last={filter_info['F_last']!r}, "
2631-
f"n_kept={filter_info['n_kept']}, "
2632-
f"n_dropped={filter_info['n_dropped']}, "
2633-
f"dropped_cohorts={filter_info['dropped_cohorts']}). "
2634-
f"Paper Appendix B.2 prescription.",
2635-
UserWarning,
2636-
stacklevel=2,
2637-
)
2639+
# `_validate_multi_period_panel` delegates to
2640+
# `_validate_had_panel_event_study`, which already emits its own
2641+
# `UserWarning` on the staggered-filter path; we do NOT warn a
2642+
# second time here (R4 code-quality fix - single emission point).
26382643

26392644
# Base period for both joint tests is the last pre-period
26402645
# (paper convention: anchor at F-1 under natural time order).

tests/test_had_pretests.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1575,6 +1575,64 @@ def test_constant_d_returns_nan_with_warning(self):
15751575
assert set(result.per_horizon_stats.keys()) == set(resid.keys())
15761576
assert all(np.isnan(v) for v in result.per_horizon_stats.values())
15771577

1578+
def test_stringified_key_collision_raises(self):
1579+
"""R4 P1 regression: two raw keys whose str() representations
1580+
collide (e.g. int 1 and str '1', or two objects sharing __str__) must
1581+
raise explicitly rather than silently overwrite one horizon in
1582+
the internal residuals_arrays map and double-count the survivor
1583+
in the sum-of-CvMs S_joint."""
1584+
G = 20
1585+
rng = np.random.default_rng(701)
1586+
d = rng.uniform(0.0, 1.0, G)
1587+
# int / str collision: str(1) == "1"
1588+
resid_int_str_collision = {
1589+
1: rng.normal(0.0, 1.0, G),
1590+
"1": rng.normal(0.0, 1.0, G),
1591+
}
1592+
fit_int_str_collision = {1: np.zeros(G), "1": np.zeros(G)}
1593+
with pytest.raises(ValueError, match="collision after str"):
1594+
stute_joint_pretest(
1595+
residuals_by_horizon=resid_int_str_collision,
1596+
fitted_by_horizon=fit_int_str_collision,
1597+
doses=d,
1598+
design_matrix=np.ones((G, 1)),
1599+
n_bootstrap=199,
1600+
seed=0,
1601+
)
1602+
1603+
# Note: int 1 and float 1.0 do NOT collide after str(): str(1) == "1"
1604+
# but str(1.0) == "1.0", and 1 == 1.0 is a single dict key anyway.
1605+
# Likewise bool(1) == True and a repeated None are the same key, so
1606+
# equal keys can never coexist in one dict to begin with.
1607+
# The second real collision case therefore uses custom objects: two
1608+
# keys that are distinct as dict keys (identity-based hash, default
1609+
# identity __eq__) but whose __str__ returns the same text.
1610+
class _WeirdLabel:
1611+
def __init__(self, s):
1612+
self._s = s
1613+
1614+
def __str__(self):
1615+
return self._s
1616+
1617+
def __hash__(self):
1618+
return hash((id(self), self._s))
1619+
1620+
a = _WeirdLabel("horizon-1")
1621+
b = _WeirdLabel("horizon-1") # same str, different object
1622+
assert a is not b
1623+
assert str(a) == str(b)
1624+
resid_obj_collision = {a: rng.normal(0.0, 1.0, G), b: rng.normal(0.0, 1.0, G)}
1625+
fit_obj_collision = {a: np.zeros(G), b: np.zeros(G)}
1626+
with pytest.raises(ValueError, match="collision after str"):
1627+
stute_joint_pretest(
1628+
residuals_by_horizon=resid_obj_collision,
1629+
fitted_by_horizon=fit_obj_collision,
1630+
doses=d,
1631+
design_matrix=np.ones((G, 1)),
1632+
n_bootstrap=199,
1633+
seed=0,
1634+
)
1635+
15781636

15791637
class TestJointPretrendsTest:
15801638
"""Tests for :func:`joint_pretrends_test` data-in wrapper."""

0 commit comments

Comments
 (0)