Skip to content

Commit 2ba1010

Browse files
igerber and claude committed
Address PR #356 CI review round 13 (1 P1 guide + code)
Bool-dtype treatment columns are now classified the same way as numeric {0, 1} rather than as "categorical". The library's binary estimators validate value support via `validate_binary` (utils.py:49-67), which accepts bool because True/False coerce to 1/0 numerically. Classifying bool as categorical silently routed valid binary DiD panels away from the supported estimator set.

Changes:
- _classify_treatment() no longer early-returns "categorical" for bool dtype. The downstream absorbing/non-absorbing logic handles bool by casting to int before np.diff (np.diff on a raw bool array is an element-wise XOR that yields bools, so the `< 0` check would never fire and a True -> False transition would be masked).
- treatment_varies_within_unit now includes bool-dtype columns (was previously hardcoded False for bool).
- Guide §2 removes the "bool = categorical" rule and adds an explicit "bool is binary" note with a pointer to validate_binary as the reason.
- profile_panel() docstring mirrors the same update.

Tests:
- test_bool_dtype_treatment_is_binary_absorbing: staggered-style bool panel with never-treated cohort -> binary_absorbing, correct has_never_treated / treatment_varies_within_unit / cohort_sizes.
- test_bool_dtype_non_absorbing: reversible False -> True -> False bool panel -> binary_non_absorbing. Guards the int-cast before np.diff so future refactors don't regress to bool XOR semantics.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 610b8aa commit 2ba1010

3 files changed

Lines changed: 67 additions & 15 deletions

File tree

diff_diff/guides/llms-autonomous.txt

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -90,11 +90,16 @@ view. Every field below appears as a top-level key in that dict.
9090
two-valued numeric column whose values are not in {0, 1} (e.g.,
9191
a dose, a discrete-integer partial-adoption score). Use
9292
`ContinuousDiD` or `HeterogeneousAdoptionDiD`.
93-
- `"categorical"`: non-numeric dtype (object / category), a bool
94-
dtype column, or a column that is entirely NaN. Often indicates
95-
a treatment arm. Encode each arm as a binary indicator and fit
96-
separately, or use a multi-treatment workflow outside the
97-
current estimator suite.
93+
- `"categorical"`: non-numeric dtype (object / category), or a
94+
column that is entirely NaN. Often indicates a treatment arm.
95+
Encode each arm as a binary indicator and fit separately, or
96+
use a multi-treatment workflow outside the current estimator
97+
suite.
98+
99+
Bool-dtype treatment columns (`True` / `False`) are classified the
100+
same way as numeric `{0, 1}`: the library's binary estimators
101+
validate on value support rather than dtype, so `True` and `False`
102+
behave like `1` and `0` for absorbing / non-absorbing classification.
98103
- **`is_staggered: bool`** - true iff treatment is `binary_absorbing` and
99104
at least two distinct first-treatment periods are observed. Drives the
100105
choice between classic DiD/TWFE and staggered-robust estimators.

diff_diff/profile.py

Lines changed: 19 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -185,12 +185,14 @@ def profile_panel(
185185
- ``"continuous"``: numeric treatment with more than two distinct
186186
values, or a 2-valued numeric whose values are not in
187187
:math:`\\{0, 1\\}` (matches the ``ContinuousDiD`` convention).
188-
- ``"categorical"``: non-numeric dtype (object / category), a
189-
boolean-dtype column, or a column that is entirely NaN.
188+
- ``"categorical"``: non-numeric dtype (object / category) or a
189+
column that is entirely NaN.
190190
191-
Boolean-dtype columns are intentionally classified as
192-
``"categorical"``; cast to ``int`` if you want binary-treatment
193-
profiling.
191+
Bool-dtype columns (``True`` / ``False``) are classified the same
192+
way as numeric ``{0, 1}``: the library's binary estimators validate
193+
on value support via :func:`diff_diff.utils.validate_binary`, so
194+
``True`` / ``False`` behave like ``1`` / ``0`` for absorbing /
195+
non-absorbing classification.
194196
195197
``has_never_treated`` is computed across both binary and
196198
continuous numeric treatment types: some unit has ``treatment ==
@@ -255,9 +257,7 @@ def profile_panel(
255257
last_tp,
256258
) = _classify_treatment(df, unit=unit, time=time, treatment=treatment)
257259

258-
if pd.api.types.is_numeric_dtype(df[treatment]) and not pd.api.types.is_bool_dtype(
259-
df[treatment]
260-
):
260+
if pd.api.types.is_numeric_dtype(df[treatment]) or pd.api.types.is_bool_dtype(df[treatment]):
261261
per_unit_distinct = df.groupby(unit)[treatment].nunique(dropna=True)
262262
treatment_varies_within_unit = bool((per_unit_distinct > 1).any())
263263
else:
@@ -352,7 +352,13 @@ def _classify_treatment(
352352
is_numeric = pd.api.types.is_numeric_dtype(col)
353353
is_bool = pd.api.types.is_bool_dtype(col)
354354

355-
if (not is_numeric) or is_bool:
355+
# Bool-dtype treatment columns are treated as binary 0/1 inputs.
356+
# The library's binary estimators validate value support via
357+
# `validate_binary`, which accepts bool because True/False coerce
358+
# to 1/0 numerically. Classifying bool columns as "categorical"
359+
# here would route a valid binary design away from the supported
360+
# estimator set.
361+
if (not is_numeric) and (not is_bool):
356362
return ("categorical", False, {}, False, False, None, None)
357363

358364
distinct = col.dropna().unique()
@@ -400,7 +406,10 @@ def _classify_treatment(
400406
for _, group in sorted_df.groupby(unit, sort=False):
401407
vals = group[treatment].to_numpy()
402408
mask = ~pd.isna(vals)
403-
observed = vals[mask]
409+
# Cast to int so np.diff on a bool-dtype column performs
410+
# arithmetic (1 - 0 = 1, 0 - 1 = -1) rather than XOR (which
411+
# would mask a True -> False transition).
412+
observed = vals[mask].astype(np.int64, copy=False)
404413
if len(observed) >= 2 and bool(np.any(np.diff(observed) < 0)):
405414
is_absorbing = False
406415
break

tests/test_profile_panel.py

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -155,6 +155,44 @@ def test_continuous_positive_dose_does_not_fire_has_always_treated():
155155
assert "has_always_treated_units" not in _alert_codes(profile)
156156

157157

158+
def test_bool_dtype_treatment_is_binary_absorbing():
159+
"""Bool-dtype treatment columns (True/False) must classify the same
160+
way as numeric {0, 1}. The library's binary estimators validate on
161+
value support via `validate_binary`, which accepts bool because
162+
True/False coerce to 1/0 numerically. Classifying bool as
163+
"categorical" would silently route valid binary DiD panels away
164+
from the supported estimator set."""
165+
first_treat = {u: 2 for u in range(11, 21)}
166+
rows = []
167+
for u in range(1, 21):
168+
for t in range(4):
169+
treated = u in first_treat and t >= first_treat[u]
170+
rows.append({"u": u, "t": t, "tr": bool(treated), "y": float(u) + 0.1 * t})
171+
df = pd.DataFrame(rows)
172+
assert df["tr"].dtype == bool
173+
profile = profile_panel(df, unit="u", time="t", treatment="tr", outcome="y")
174+
assert profile.treatment_type == "binary_absorbing"
175+
assert profile.has_never_treated is True
176+
assert profile.has_always_treated is False
177+
assert profile.treatment_varies_within_unit is True
178+
assert profile.cohort_sizes == {2: 10}
179+
180+
181+
def test_bool_dtype_non_absorbing():
182+
"""Reversible 0 -> 1 -> 0 treatment expressed as a bool column must
183+
classify as binary_non_absorbing, same as numeric."""
184+
rows = []
185+
for u in range(1, 11):
186+
seq = [False, True, True, False, False] if u > 5 else [False] * 5
187+
for t, tr in enumerate(seq):
188+
rows.append({"u": u, "t": t, "tr": tr, "y": float(u) + 0.1 * t})
189+
df = pd.DataFrame(rows)
190+
assert df["tr"].dtype == bool
191+
profile = profile_panel(df, unit="u", time="t", treatment="tr", outcome="y")
192+
assert profile.treatment_type == "binary_non_absorbing"
193+
assert profile.has_never_treated is True
194+
195+
158196
def test_categorical_treatment_object_dtype():
159197
rows = []
160198
for u in range(1, 11):

0 commit comments

Comments
 (0)