Skip to content

Commit 9a87f4e

Browse files
igerberclaude
andcommitted
Add regression tests for warning paths and PS cache with collinearity
- Test base period outside panel: warns and produces finite results - Test empty DDD subgroup: warns with cell details, still estimates - Test collinear covariates with PS cache reuse: all (g,t) finite Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9bc0341 commit 9a87f4e

1 file changed

Lines changed: 68 additions & 0 deletions

File tree

tests/test_staggered_triple_diff.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,3 +403,71 @@ def test_to_dict(self, simple_data):
403403
assert "overall_att" in d
404404
assert "n_obs" in d
405405
assert "estimation_method" in d
406+
407+
408+
# ---------------------------------------------------------------------------
409+
# Regression tests for specific bug fixes
410+
# ---------------------------------------------------------------------------
411+
412+
class TestStaggeredTripleDiffRegressions:
413+
def test_base_period_outside_panel_warns(self):
414+
"""Cohort with base period before observed panel should warn, not crash."""
415+
# Cohort g=2 with anticipation=1 needs base_period = g-1-1 = 0,
416+
# but periods start at 1. Should warn and skip that cell.
417+
data = generate_staggered_ddd_data(
418+
n_units=100, n_periods=4, cohort_periods=[2, 4],
419+
seed=77,
420+
)
421+
est = StaggeredTripleDifference(anticipation=1)
422+
import warnings as _w
423+
with _w.catch_warnings(record=True) as caught:
424+
_w.simplefilter("always")
425+
res = est.fit(data, "outcome", "unit", "period",
426+
"first_treat", "eligibility")
427+
base_period_warnings = [
428+
w for w in caught if "outside the observed panel" in str(w.message)
429+
]
430+
assert len(base_period_warnings) > 0, "Expected warning about base period"
431+
assert np.isfinite(res.overall_att)
432+
433+
def test_empty_subgroup_warns(self):
434+
"""Data where one (S,Q) cell is empty should warn, not crash."""
435+
data = generate_staggered_ddd_data(
436+
n_units=100, cohort_periods=[4, 6], seed=88,
437+
)
438+
# Remove all ineligible units from cohort 6 to make (S=6,Q=0) empty
439+
mask = ~((data["first_treat"] == 6) & (data["eligibility"] == 0))
440+
data = data[mask].reset_index(drop=True)
441+
est = StaggeredTripleDifference()
442+
import warnings as _w
443+
with _w.catch_warnings(record=True) as caught:
444+
_w.simplefilter("always")
445+
res = est.fit(data, "outcome", "unit", "period",
446+
"first_treat", "eligibility")
447+
subgroup_warnings = [
448+
w for w in caught if "Empty subgroup" in str(w.message)
449+
]
450+
assert len(subgroup_warnings) > 0, "Expected warning about empty subgroup"
451+
assert np.isfinite(res.overall_att)
452+
453+
def test_collinear_covariates_cached_ps_finite(self):
454+
"""Collinear covariates with PS cache reuse should produce finite results."""
455+
data = generate_staggered_ddd_data(
456+
n_units=200, treatment_effect=3.0,
457+
add_covariates=True, seed=55,
458+
)
459+
# Add a perfectly collinear covariate (x3 = 2*x1)
460+
data["x3"] = 2.0 * data["x1"]
461+
est = StaggeredTripleDifference(
462+
estimation_method="dr", rank_deficient_action="warn",
463+
)
464+
import warnings as _w
465+
with _w.catch_warnings(record=True):
466+
_w.simplefilter("always")
467+
res = est.fit(data, "outcome", "unit", "period",
468+
"first_treat", "eligibility",
469+
covariates=["x1", "x2", "x3"])
470+
# All group-time effects should be finite despite collinearity
471+
for (g, t), eff in res.group_time_effects.items():
472+
assert np.isfinite(eff["effect"]), f"Non-finite ATT at (g={g},t={t})"
473+
assert np.isfinite(eff["se"]), f"Non-finite SE at (g={g},t={t})"

0 commit comments

Comments
 (0)