Skip to content

Commit 8ce7c5d

Browse files
igerberclaude
andcommitted
Tighten validation and R parity test assertions
- Validate exact period-set equality (not just counts) for balanced panel - Reject non-finite outcomes (Inf) and covariates up front - R parity tests now assert GT vector lengths and (g,t) label identity before comparing ATT/SE values Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9a87f4e commit 8ce7c5d

2 files changed

Lines changed: 44 additions & 7 deletions

File tree

diff_diff/staggered_triple_diff.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,23 @@ def _validate_inputs(
624624
for col in [outcome, first_treat, eligibility]:
625625
if df[col].isna().any():
626626
raise ValueError(f"Column '{col}' contains missing values.")
627+
628+
# Reject non-finite outcomes (Inf/-Inf)
629+
if not np.all(np.isfinite(df[outcome])):
630+
raise ValueError(
631+
f"Column '{outcome}' contains non-finite values (Inf/-Inf). "
632+
"All outcome values must be finite."
633+
)
634+
635+
# Reject non-finite covariates
636+
if covariates:
637+
for cov in covariates:
638+
if df[cov].isna().any():
639+
raise ValueError(f"Covariate '{cov}' contains missing values.")
640+
if not np.all(np.isfinite(df[cov])):
641+
raise ValueError(
642+
f"Covariate '{cov}' contains non-finite values."
643+
)
627644
if df[eligibility].nunique() < 2:
628645
raise ValueError(
629646
"Need both eligible (Q=1) and ineligible (Q=0) units. "
@@ -638,16 +655,16 @@ def _validate_inputs(
638655
f"{int(dup.sum())} duplicates detected. Panel must have unique rows."
639656
)
640657

641-
# Check balanced panel — every unit observed in every period
642-
all_periods = df[time].unique()
643-
n_global_periods = len(all_periods)
644-
periods_per_unit = df.groupby(unit)[time].nunique()
645-
incomplete = periods_per_unit[periods_per_unit < n_global_periods]
646-
if len(incomplete) > 0:
658+
# Check balanced panel — every unit observed in exactly the global period set
659+
global_periods = set(df[time].unique())
660+
n_global_periods = len(global_periods)
661+
unit_period_sets = df.groupby(unit)[time].apply(set)
662+
mismatched = unit_period_sets[unit_period_sets != global_periods]
663+
if len(mismatched) > 0:
647664
raise ValueError(
648665
"Unbalanced panel detected. All units must be observed in "
649666
f"all {n_global_periods} periods. "
650-
f"Found {len(incomplete)} units with fewer periods."
667+
f"Found {len(mismatched)} units with different period sets."
651668
)
652669

653670
# Check time-invariant first_treat

tests/test_methodology_staggered_triple_diff.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,14 @@ def test_gt_att_matches_r(self, r_results, key):
152152
res = _run_python(data, r["est_method"], r["control_group"])
153153

154154
py_gt = sorted(res.group_time_effects.items())
155+
r_gt = list(zip(r["gt_groups"], r["gt_periods"]))
156+
assert len(py_gt) == len(r["gt_att"]), (
157+
f"{key}: Python has {len(py_gt)} GT cells, R has {len(r['gt_att'])}"
158+
)
155159
for i, ((g, t), eff) in enumerate(py_gt):
160+
assert (g, t) == (r_gt[i][0], r_gt[i][1]), (
161+
f"{key}: GT cell mismatch at index {i}: Python=({g},{t}), R={r_gt[i]}"
162+
)
156163
_assert_close(
157164
eff["effect"], r["gt_att"][i],
158165
ATT_RTOL, ATT_ATOL,
@@ -166,6 +173,9 @@ def test_gt_se_matches_r(self, r_results, key):
166173
res = _run_python(data, r["est_method"], r["control_group"])
167174

168175
py_gt = sorted(res.group_time_effects.items())
176+
assert len(py_gt) == len(r["gt_se"]), (
177+
f"{key}: Python has {len(py_gt)} GT cells, R has {len(r['gt_se'])}"
178+
)
169179
for i, ((g, t), eff) in enumerate(py_gt):
170180
_assert_close(
171181
eff["se"], r["gt_se"][i],
@@ -197,7 +207,14 @@ def test_gt_att_matches_r(self, r_results, key):
197207
res = _run_python(data, r["est_method"], r["control_group"])
198208

199209
py_gt = sorted(res.group_time_effects.items())
210+
r_gt = list(zip(r["gt_groups"], r["gt_periods"]))
211+
assert len(py_gt) == len(r["gt_att"]), (
212+
f"{key}: Python has {len(py_gt)} GT cells, R has {len(r['gt_att'])}"
213+
)
200214
for i, ((g, t), eff) in enumerate(py_gt):
215+
assert (g, t) == (r_gt[i][0], r_gt[i][1]), (
216+
f"{key}: GT cell mismatch at index {i}: Python=({g},{t}), R={r_gt[i]}"
217+
)
201218
_assert_close(
202219
eff["effect"], r["gt_att"][i],
203220
ATT_RTOL, ATT_ATOL,
@@ -211,6 +228,9 @@ def test_gt_se_matches_r(self, r_results, key):
211228
res = _run_python(data, r["est_method"], r["control_group"])
212229

213230
py_gt = sorted(res.group_time_effects.items())
231+
assert len(py_gt) == len(r["gt_se"]), (
232+
f"{key}: Python has {len(py_gt)} GT cells, R has {len(r['gt_se'])}"
233+
)
214234
for i, ((g, t), eff) in enumerate(py_gt):
215235
_assert_close(
216236
eff["se"], r["gt_se"][i],

0 commit comments

Comments
 (0)