Skip to content

Commit 2ace61d

Browse files
igerberclaude
andcommitted
Address CI AI review: reject negative first_treat in ContinuousDiD
CI AI re-review flagged (P1) that the previous commit claimed "-inf will be rejected by downstream validators" in both the code comment and REGISTRY.md, but no such validator existed. After the `+inf -> 0` normalization, `first_treat < 0` units fell out of both the treated (g > 0) and never-treated (g == 0) masks, so the affected units were silently excluded from the estimator — exactly the axis-E silent failure the PR was closing. - ContinuousDiD.fit() now validates `first_treat < 0` explicitly post-normalization and raises ValueError with the row count. -inf, -2, and any other negative value are all rejected. - REGISTRY.md note rewritten to match the implemented behavior. - Existing -inf test replaced with one that asserts `pytest.raises(ValueError)` matching the row-count message, plus a positive regression test confirming +inf warning stays silent on panels with only valid 0/positive `first_treat` values. - tests/test_utils.py::test_silent_on_balanced_panel tightened: the balanced-panel silence assertion now filters on any warning containing "dropped", so a regression that changed the warning label would no longer hide a genuine drop signal. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 2d344b2 commit 2ace61d

4 files changed

Lines changed: 55 additions & 18 deletions

File tree

diff_diff/continuous_did.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,9 @@ def fit(
233233
# recategorization here would shift the control composition (axis-E
234234
# silent coercion). Only positive infinity is recoded (to match the
235235
# existing `.replace([np.inf, float("inf")], 0)` semantics on the
236-
# next line); `-inf` is neither counted here nor recoded, so a
237-
# downstream validator will reject it if present.
238-
inf_mask = np.isposinf(df[first_treat].values)
236+
# next line).
237+
first_treat_vals = df[first_treat].values
238+
inf_mask = np.isposinf(first_treat_vals)
239239
n_inf_first_treat = int(inf_mask.sum())
240240
if n_inf_first_treat > 0:
241241
warnings.warn(
@@ -245,6 +245,19 @@ def fit(
245245
UserWarning,
246246
stacklevel=2,
247247
)
248+
# Reject negative first_treat values (including -inf) explicitly.
249+
# Without this guard they would survive preprocessing but fall out of
250+
# both the treated (g > 0) and never-treated (g == 0) masks, silently
251+
# excluding the affected units.
252+
negative_mask = first_treat_vals < 0
253+
n_negative_first_treat = int(negative_mask.sum())
254+
if n_negative_first_treat > 0:
255+
raise ValueError(
256+
f"{n_negative_first_treat} row(s) have negative '{first_treat}' "
257+
f"values (including -inf). Valid values are 0 (never-treated) "
258+
f"or a positive treatment period; such units would otherwise "
259+
f"be silently excluded from both treated and control pools."
260+
)
248261
df[first_treat] = df[first_treat].replace([np.inf, float("inf")], 0)
249262

250263
# Drop units with positive first_treat but zero dose (R convention)

docs/methodology/REGISTRY.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,7 @@ See `docs/methodology/continuous-did.md` Section 4 for full details.
720720
- [ ] Lowest-dose-as-control (Remark 3.1)
721721
- [x] Survey design support (Phase 3): weighted B-spline OLS, TSL on influence functions; bootstrap+survey supported (Phase 6)
722722
- **Note:** ContinuousDiD bootstrap with survey weights supported (Phase 6) via PSU-level multiplier weights
723-
- **Note:** The R-style convention of coding never-treated units as `first_treat=inf` is still accepted and normalized to `first_treat=0` internally, but the estimator now emits a `UserWarning` reporting the row count so the silent recategorization is surfaced (axis-E silent coercion under the Phase 2 audit). Only `+inf` is recoded (matching the R convention); `-inf` passes through untouched and will be rejected by downstream validators. Pass `0` directly to avoid the warning.
723+
- **Note:** The R-style convention of coding never-treated units as `first_treat=inf` is still accepted and normalized to `first_treat=0` internally, but the estimator now emits a `UserWarning` reporting the row count so the silent recategorization is surfaced (axis-E silent coercion under the Phase 2 audit). Only `+inf` is recoded (matching the R convention). Any **negative** `first_treat` value (including `-inf`) raises `ValueError` with the row count, since such units would otherwise silently fall out of both the treated (`g > 0`) and never-treated (`g == 0`) masks. Pass `0` directly for never-treated units to avoid the warning.
724724
- **Note:** Rows where `first_treat=0` (never-treated) carry a nonzero `dose` are silently zeroed for internal consistency (never-treated cells must have `D=0` in the dose response). The estimator now emits a `UserWarning` with the affected row count before the zeroing, so unintended nonzero doses on never-treated rows are no longer absorbed without a signal (axis-E silent coercion).
725725

726726
---

tests/test_continuous_did.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -723,16 +723,20 @@ def test_clean_never_treated_doses_silent(self):
723723
]
724724
assert coerce_warnings == []
725725

726-
def test_negative_inf_first_treat_does_not_trigger_recategorization_warning(self):
727-
"""-inf first_treat is NOT recoded to 0 by `.replace([inf, float("inf")], 0)`,
728-
so the recategorization warning (which used to count both +inf and -inf
729-
via np.isinf) must not fire for -inf rows."""
730-
import warnings
726+
def test_negative_first_treat_raises_with_row_count(self):
727+
"""Negative `first_treat` (including -inf) must raise ValueError with
728+
the affected row count. Without this guard the affected units fall
729+
out of both the treated (g > 0) and never-treated (g == 0) masks and
730+
are silently excluded from the estimator."""
731731
rows = []
732732
for unit in range(4):
733-
# Unit 0 carries -inf (not recoded, so downstream validation should
734-
# see it as-is). Others are untreated with dose=0.
735-
ft = -np.inf if unit == 0 else 0.0
733+
# Unit 0: -inf. Unit 1: -2. Others: valid (0 or positive).
734+
if unit == 0:
735+
ft = -np.inf
736+
elif unit == 1:
737+
ft = -2.0
738+
else:
739+
ft = 0.0
736740
for t in range(1, 4):
737741
rows.append({
738742
"unit": unit, "period": t, "outcome": float(unit + t),
@@ -741,6 +745,28 @@ def test_negative_inf_first_treat_does_not_trigger_recategorization_warning(self
741745
data = pd.DataFrame(rows)
742746
est = ContinuousDiD()
743747

748+
with pytest.raises(
749+
ValueError,
750+
match=r"6 row\(s\) have negative 'first_treat' values",
751+
):
752+
est.fit(data, "outcome", "unit", "period", "first_treat", "dose")
753+
754+
def test_positive_inf_warning_silent_when_no_inf(self):
755+
"""+inf warning is gated on +inf rows only; panels with only valid
756+
non-negative values (including just 0 and positive periods) must
757+
never trigger the recategorization warning."""
758+
import warnings
759+
rows = []
760+
for unit in range(4):
761+
ft = 0.0 if unit < 2 else 2.0
762+
for t in range(1, 4):
763+
rows.append({
764+
"unit": unit, "period": t, "outcome": float(unit + t),
765+
"first_treat": ft, "dose": 0.0 if unit < 2 else 1.0,
766+
})
767+
data = pd.DataFrame(rows)
768+
est = ContinuousDiD()
769+
744770
with warnings.catch_warnings(record=True) as w:
745771
warnings.simplefilter("always")
746772
try:
@@ -749,9 +775,7 @@ def test_negative_inf_first_treat_does_not_trigger_recategorization_warning(self
749775
pass
750776

751777
inf_warnings = [x for x in w if "inf in 'first_treat'" in str(x.message)]
752-
assert inf_warnings == [], (
753-
"-inf must not trigger the +inf recategorization warning"
754-
)
778+
assert inf_warnings == []
755779

756780
def test_inf_first_treat_warning_counts_rows_not_units(self):
757781
"""The warning counts affected rows (not units). On a panel with

tests/test_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -849,9 +849,9 @@ def test_silent_on_balanced_panel(self):
849849
treatment_group="treated", unit="unit",
850850
)
851851

852-
drop_warnings = [
853-
x for x in w if "check_parallel_trends dropped" in str(x.message)
854-
]
852+
# Generic filter on "dropped" catches both the old and new label so a
853+
# regression in the label wouldn't hide a real silent-drop warning.
854+
drop_warnings = [x for x in w if "dropped" in str(x.message).lower()]
855855
assert drop_warnings == []
856856

857857
def test_warns_on_nan_outcomes_with_excess_drop_count(self):

0 commit comments

Comments
 (0)