Skip to content

Commit 67bc6db

Browse files
igerber and claude committed
Fix epv_summary column schema on empty results, update test comments
Ensure epv_summary(show_all=False) returns a DataFrame with the correct column schema even when no entries have low EPV, across all three results classes. Fix the remaining test comments to use intercept-excluded EPV arithmetic.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 7404815 commit 67bc6db

4 files changed

Lines changed: 8 additions & 5 deletions

File tree

diff_diff/staggered_results.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,8 @@ def epv_summary(self, show_all: bool = False) -> pd.DataFrame:
310310
"is_low": diag.get("is_low", False),
311311
}
312312
)
313-
return pd.DataFrame(rows)
313+
cols = ["group", "time", "epv", "n_events", "n_params", "is_low"]
314+
return pd.DataFrame(rows, columns=cols) if rows else pd.DataFrame(columns=cols)
314315

315316
def print_summary(self, alpha: Optional[float] = None) -> None:
316317
"""Print summary to stdout."""

diff_diff/staggered_triple_diff_results.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ def epv_summary(self, show_all: bool = False) -> pd.DataFrame:
286286
"is_low": diag.get("is_low", False),
287287
}
288288
)
289-
return pd.DataFrame(rows)
289+
cols = ["group", "time", "epv", "n_events", "n_params", "is_low"]
290+
return pd.DataFrame(rows, columns=cols) if rows else pd.DataFrame(columns=cols)
290291

291292
def to_dataframe(self, level: str = "group_time") -> pd.DataFrame:
292293
"""

diff_diff/triple_diff.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,8 @@ def epv_summary(self, show_all: bool = False) -> pd.DataFrame:
335335
"is_low": diag.get("is_low", False),
336336
}
337337
)
338-
return pd.DataFrame(rows)
338+
cols = ["subgroup", "epv", "n_events", "n_params", "is_low"]
339+
return pd.DataFrame(rows, columns=cols) if rows else pd.DataFrame(columns=cols)
339340

340341

341342
# =============================================================================

tests/test_linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1728,11 +1728,11 @@ def test_epv_threshold_configurable(self):
17281728
# 15 events, 2 predictor variables → EPV = 7.5
17291729
y = np.concatenate([np.ones(15), np.zeros(n - 15)])
17301730

1731-
# Default threshold 10 → should warn (EPV=5)
1731+
# Default threshold 10 → should warn (EPV=7.5 < 10)
17321732
with pytest.warns(UserWarning, match="Low Events Per Variable"):
17331733
solve_logit(X, y, epv_threshold=10)
17341734

1735-
# Threshold 3 → should not warn (EPV=5 >= 3)
1735+
# Threshold 3 → should not warn (EPV=7.5 >= 3)
17361736
import warnings
17371737

17381738
with warnings.catch_warnings(record=True) as w:

0 commit comments

Comments
 (0)