Skip to content

Commit 50ab3bd

Browse files
igerberclaude
andcommitted
Use pd.api.types.is_numeric_dtype for nullable dtype support
Replace np.issubdtype with pd.api.types.is_numeric_dtype so pandas nullable extension dtypes (Int64, Float64) are accepted as numeric. Add regression test with Float64 outcome column. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 8ac2222 commit 50ab3bd

2 files changed

Lines changed: 16 additions & 1 deletion

File tree

diff_diff/prep.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1541,7 +1541,7 @@ def aggregate_survey(
15411541
# --- Precompute full-length outcome/covariate arrays ---
15421542
n_total = len(data)
15431543
all_vars = outcome_cols + cov_cols
1544-
non_numeric = [v for v in all_vars if not np.issubdtype(data[v].dtype, np.number)]
1544+
non_numeric = [v for v in all_vars if not pd.api.types.is_numeric_dtype(data[v])]
15451545
if non_numeric:
15461546
raise ValueError(
15471547
f"Non-numeric column(s) in outcomes/covariates: {non_numeric}. "

tests/test_prep.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2463,6 +2463,21 @@ def test_error_non_numeric_outcome(self, micro_data, design):
24632463
survey_design=design,
24642464
)
24652465

2466+
def test_nullable_numeric_dtypes(self):
2467+
"""Pandas nullable Int64/Float64 dtypes are accepted as numeric."""
2468+
data = pd.DataFrame(
2469+
{
2470+
"geo": np.repeat(["A", "B"], 10),
2471+
"time": np.ones(20, dtype=int),
2472+
"wt": np.ones(20),
2473+
"y": pd.array(np.random.RandomState(1).normal(0, 1, 20), dtype="Float64"),
2474+
}
2475+
)
2476+
design = SurveyDesign(weights="wt")
2477+
panel, _ = aggregate_survey(data, by=["geo", "time"], outcomes="y", survey_design=design)
2478+
assert len(panel) == 2
2479+
assert panel["y_mean"].notna().all()
2480+
24662481
def test_error_empty_data(self, design):
24672482
"""Empty DataFrame raises ValueError."""
24682483
empty = pd.DataFrame(columns=["state", "year", "y", "wt", "stratum", "cluster"])

0 commit comments

Comments
 (0)