Commit b2f4dc4

igerber and claude committed
fix: check realized never-treated count, not just fraction
Small never_treated_frac values (e.g., 0.01 with n_units=50) floor to zero never-treated units via int(), silently breaking the conditional_pt contract. Now checks int(n_units * never_treated_frac) >= 1 and reports the realized count in the error message.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
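The flooring described in the commit message can be reproduced in isolation. The sketch below is illustrative only: the helper names `realized_never_treated`, `old_check_rejects`, and `new_check_rejects` are hypothetical and not part of `diff_diff`, and the `conditional_pt != 0.0` half of the guard is omitted for brevity.

```python
def realized_never_treated(n_units: int, never_treated_frac: float) -> int:
    """Never-treated count after int() truncation (same expression as the new check)."""
    return int(n_units * never_treated_frac)


def old_check_rejects(never_treated_frac: float) -> bool:
    # Previous guard: only an exact zero fraction was rejected.
    return never_treated_frac == 0.0


def new_check_rejects(n_units: int, never_treated_frac: float) -> bool:
    # New guard: reject whenever the realized count is below one.
    return realized_never_treated(n_units, never_treated_frac) < 1


n_units, frac = 50, 0.01
print(realized_never_treated(n_units, frac))  # 50 * 0.01 = 0.5, truncated to 0
print(old_check_rejects(frac))                # False: the old guard let this through
print(new_check_rejects(n_units, frac))       # True: the new guard catches it
```

With `n_units=50`, any fraction below 0.02 truncates to zero never-treated units, which is the case the second test in this commit exercises.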
1 parent c86c0dc commit b2f4dc4

3 files changed

Lines changed: 20 additions & 11 deletions

diff_diff/prep_dgp.py

Lines changed: 9 additions & 7 deletions
@@ -1311,8 +1311,8 @@ def generate_survey_did_data(
 Conditional on x1, trends remain parallel (conditional PT holds).
 DR/IPW estimators with covariates recover truth; no-covariate
 estimators are biased. Uses normalized time (t/n_periods) for
-scale independence. Requires ``add_covariates=True`` and
-``never_treated_frac > 0`` (the x1 mean shift only differentiates
+scale independence. Requires ``add_covariates=True`` and at least
+one never-treated unit (the x1 mean shift only differentiates
 ever-treated from never-treated units).
 
 .. note:: When used with ``icc``, the ICC calibration is approximate
@@ -1437,12 +1437,14 @@ def generate_survey_did_data(
         )
     if conditional_pt != 0.0 and not add_covariates:
         raise ValueError("conditional_pt requires add_covariates=True")
-    if conditional_pt != 0.0 and never_treated_frac == 0.0:
+    if conditional_pt != 0.0 and int(n_units * never_treated_frac) < 1:
         raise ValueError(
-            "conditional_pt requires never_treated_frac > 0. The x1 mean shift "
-            "applies to all ever-treated units; without a never-treated group, "
-            "treated and control units share the same x1 distribution and "
-            "unconditional parallel trends are not violated."
+            "conditional_pt requires at least one never-treated unit "
+            f"(n_units={n_units}, never_treated_frac={never_treated_frac} "
+            f"yields {int(n_units * never_treated_frac)} never-treated). "
+            "The x1 mean shift applies to all ever-treated units; without a "
+            "never-treated group, treated and control units share the same x1 "
+            "distribution and unconditional parallel trends are not violated."
         )
 
     # --- ICC -> psu_re_sd resolution ---

docs/methodology/REGISTRY.md

Lines changed: 2 additions & 2 deletions
@@ -2708,8 +2708,8 @@ The 8-step workflow in `docs/llms-practitioner.txt` is adapted from Baker et al.
 E[x1 | treated] != E[x1 | control], the average time trend differs by group
 (unconditional PT fails). Conditional on x1, trends are identical (conditional
 PT holds). DR/IPW estimators with x1 as covariate recover the true ATT.
-Requires `never_treated_frac > 0`; rejected otherwise because the x1 mean
-shift only differentiates ever-treated from never-treated units.
+Requires at least one never-treated unit (rejected otherwise because the x1
+mean shift only differentiates ever-treated from never-treated units).
 - **Note:** When `conditional_pt != 0` is combined with `icc`, the ICC
 calibration is approximate. The x1 mean shift creates a mixture distribution
 with marginal Var(x1) = 1 + p_treated * (1 - p_treated) > 1, slightly
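
The variance formula in the note above is consistent with a two-component normal mixture that has unit within-group variance and a mean shift of one between groups. Both of those are assumptions inferred from the formula, not confirmed against `prep_dgp.py`. Under them, a rough check compares the closed form with a simulated draw; `mixture_variance` and `simulate_x1_variance` are hypothetical names used only for this sketch.

```python
import random
import statistics


def mixture_variance(p_treated: float, shift: float = 1.0) -> float:
    """Closed-form marginal variance: 1 + p * (1 - p) * shift**2 (assumed setup)."""
    return 1.0 + p_treated * (1.0 - p_treated) * shift**2


def simulate_x1_variance(p_treated: float, shift: float = 1.0,
                         n: int = 100_000, seed: int = 0) -> float:
    """Draw x1 from the assumed mixture and return its population variance."""
    rng = random.Random(seed)
    draws = []
    for _ in range(n):
        # Ever-treated units get the shifted mean; never-treated stay at zero.
        mean = shift if rng.random() < p_treated else 0.0
        draws.append(rng.gauss(mean, 1.0))
    return statistics.pvariance(draws)


p = 0.7
analytic = mixture_variance(p)       # 1 + 0.7 * 0.3
simulated = simulate_x1_variance(p)  # should land close to the analytic value
print(analytic, simulated)
```

When `p_treated` is 0 or 1 the excess term vanishes, which matches the commit's rationale: without both groups present, the shift no longer separates treated from control units.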

tests/test_prep.py

Lines changed: 9 additions & 2 deletions
@@ -1969,14 +1969,21 @@ def test_te_covariate_interaction_validation(self):
     # --- conditional_pt parameter tests ---
 
     def test_conditional_pt_requires_never_treated(self):
-        """conditional_pt requires never_treated_frac > 0."""
+        """conditional_pt requires at least one never-treated unit."""
         from diff_diff.prep_dgp import generate_survey_did_data
 
-        with pytest.raises(ValueError, match="conditional_pt requires never_treated_frac"):
+        # Exact zero fraction
+        with pytest.raises(ValueError, match="conditional_pt requires at least one"):
             generate_survey_did_data(
                 add_covariates=True, conditional_pt=0.3,
                 never_treated_frac=0.0, seed=42,
             )
+        # Small fraction that floors to zero never-treated units
+        with pytest.raises(ValueError, match="conditional_pt requires at least one"):
+            generate_survey_did_data(
+                n_units=50, add_covariates=True, conditional_pt=0.3,
+                never_treated_frac=0.01, seed=42,
+            )
 
     def test_conditional_pt_requires_covariates(self):
         """conditional_pt requires add_covariates=True."""
