Skip to content

Commit 1f044e7

Browse files
igerberclaude
andcommitted
Address PR #350 CI review round 4: P1 unused categorical levels trip balance check
**P1:** Round 3 added `observed=False` to the balance-check groupby to silence a pandas FutureWarning, but that creates a false-unbalance bug: on ordered-categorical `time_col` with extra category levels beyond the observed periods, `observed=False` materializes zero-count unit-period cells for the unused levels, and the balance check rejects the panel. Fix: switched to `observed=True`. This tells categorical groupby to count only OBSERVED unit-period cells, matching the `periods_list` (observed uniques) that the rest of the validator is keyed to. No change for numeric / datetime time columns. **Test added:** `test_ordered_categorical_with_unused_levels_accepted` declares categories `["pre0", "pre1", "pre2", "post1", "post2", "post3"]` but only observes `{"pre1", "pre2", "post1", "post2"}`; asserts the fit succeeds with `F="post1"` and `event_times=[-2, 0, 1]`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent cb11a1f commit 1f044e7

2 files changed

Lines changed: 51 additions & 3 deletions

File tree

diff_diff/had.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,9 +1110,15 @@ def _sort_key(x: Any) -> Tuple[bool, Any]:
11101110
)
11111111

11121112
# Balanced panel on the (possibly-filtered) data: every unit appears
1113-
# exactly once per period. ``observed=False`` preserves current
1114-
# behavior on categorical time columns (pandas' default is changing).
1115-
counts = data_filtered.groupby([unit_col, time_col], observed=False).size()
1113+
# exactly once per period. ``observed=True`` tells categorical
1114+
# groupby to count only OBSERVED unit-period cells. Without it, a
1115+
# time_col with an ordered-categorical dtype carrying extra unused
1116+
# category levels (beyond the periods actually present in the data)
1117+
# would expand to zero-count cells and the balance check would
1118+
# falsely reject valid panels. The rest of the validator is keyed
1119+
# to ``periods_list`` (observed unique values) so this stays
1120+
# consistent.
1121+
counts = data_filtered.groupby([unit_col, time_col], observed=True).size()
11161122
if (counts != 1).any():
11171123
n_bad = int((counts != 1).sum())
11181124
raise ValueError(

tests/test_had.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2865,6 +2865,48 @@ def test_unordered_string_time_col_rejected(self):
28652865
panel, "outcome", "dose", "period", "unit", aggregate="event_study"
28662866
)
28672867

2868+
def test_ordered_categorical_with_unused_levels_accepted(self):
2869+
"""Ordered categorical with extra unused category levels fits.
2870+
2871+
Covers CI reviewer round 4 P1: the balanced-panel check must
2872+
use ``observed=True`` on categorical groupby so unused category
2873+
levels don't expand to zero-count cells and falsely trip the
2874+
balance guard.
2875+
"""
2876+
rng = np.random.default_rng(0)
2877+
G = 40
2878+
# Observed periods: pre1, pre2, post1, post2
2879+
# Declared categories: ALSO include pre0 (unused) and post3 (unused)
2880+
all_categories = ["pre0", "pre1", "pre2", "post1", "post2", "post3"]
2881+
observed = ["pre1", "pre2", "post1", "post2"]
2882+
cat_dtype = pd.CategoricalDtype(categories=all_categories, ordered=True)
2883+
rows = []
2884+
d_post = rng.uniform(0.1, 1.0, G)
2885+
d_post[0] = 0.0
2886+
for g in range(G):
2887+
for label in observed:
2888+
dose = d_post[g] if label in ("post1", "post2") else 0.0
2889+
rows.append(
2890+
{
2891+
"unit": g,
2892+
"period": label,
2893+
"dose": dose,
2894+
"outcome": rng.standard_normal(),
2895+
}
2896+
)
2897+
panel = pd.DataFrame(rows)
2898+
panel["period"] = panel["period"].astype(cat_dtype)
2899+
with warnings.catch_warnings():
2900+
warnings.simplefilter("ignore", UserWarning)
2901+
result = HeterogeneousAdoptionDiD(design="auto").fit(
2902+
panel, "outcome", "dose", "period", "unit", aggregate="event_study"
2903+
)
2904+
# F should be post1 (first observed post-period); event_times
2905+
# should be [-2, 0, 1] (e=-1 for anchor pre2 is skipped).
2906+
assert result.F == "post1"
2907+
assert result.event_times.tolist() == [-2, 0, 1]
2908+
assert result.n_units == G
2909+
28682910
def test_ordered_categorical_time_col_accepted(self):
28692911
"""Ordered categorical time dtype passes the ordered-time check."""
28702912
rng = np.random.default_rng(0)

0 commit comments

Comments
 (0)