Skip to content

Commit ce73058

Browse files
igerberclaude
andcommitted
Address local AI review (Wave 2): exact dose-invariance, tolerance-aware integer detection
P1 from local review (`treatment_dose.is_time_invariant`): - Removed `np.round(..., 8)` tolerance from `_compute_treatment_dose`'s per-unit non-zero distinct-count check. The documented contract is "per-unit non-zero doses have at most one distinct value" (exact), but the implementation was rounding to 8 decimals before comparing, silently classifying tiny-but-real dose variation as time-invariant and contradicting the docstring + CHANGELOG + autonomous guide §2. Now uses exact `np.unique(unit_nonzero).size > 1`. Added a regression test (`test_treatment_dose_distinguishes_doses_at_high_precision`) for a unit with two non-equal doses separated by 1e-9 (sub the previous rounding window) — asserts `is_time_invariant=False`. Related dead-code removal: - Removed the `len(nonzero) == 0` defensive branch in `_compute_treatment_dose`. `treatment_type == "continuous"` is reached only when the treatment column has more than two distinct values OR a 2-valued numeric outside `{0, 1}`; an all-zero numeric column is classified as `binary_absorbing` and never reaches this branch, so `nonzero` is guaranteed non-empty. Removing the branch eliminates the NaN-vs-Optional[float] inconsistency the reviewer flagged on `dose_min/max/mean`. P2 from local review (`is_integer_valued` brittleness): - Switched from `np.equal(np.mod(arr, 1.0), 0.0)` to `np.isclose(arr, np.round(arr), rtol=0.0, atol=1e-12)`. The treatment / outcome column is user input (system boundary), and CSV-roundtripped count columns commonly carry float64 representation noise (e.g., `1.0` stored as `1.0000000000000002`). Tolerance-aware integer detection is the right discipline at the boundary; downstream the `is_count_like` heuristic remains gated on this AND `pct_zeros > 0` AND `skewness > 0.5` AND `n_distinct > 2`, so isolated noise can't flip the classification. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 335de3d commit ce73058

2 files changed

Lines changed: 50 additions & 11 deletions

File tree

diff_diff/profile.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,11 @@ def _compute_outcome_shape(valid: pd.Series, outcome_dtype_kind: str) -> Optiona
613613
value_min = float(arr.min())
614614
value_max = float(arr.max())
615615

616-
is_integer_valued = bool(np.all(np.equal(np.mod(arr, 1.0), 0.0)))
616+
# Tolerance-aware integer detection: a CSV-roundtripped count column
617+
# may carry float64 representation noise (e.g., 1.0 stored as
618+
# 1.0000000000000002), and that should still classify as
619+
# integer-valued for the purpose of the count-like heuristic.
620+
is_integer_valued = bool(np.all(np.isclose(arr, np.round(arr), rtol=0.0, atol=1e-12)))
617621
is_bounded_unit = bool(np.all((arr >= 0.0) & (arr <= 1.0)))
618622

619623
skewness: Optional[float] = None
@@ -672,24 +676,28 @@ def _compute_treatment_dose(
672676
n_distinct_doses = int(col.nunique())
673677
has_zero_dose = bool((col == 0).any())
674678

679+
# `treatment_type == "continuous"` is reached only when the
680+
# treatment column has more than two distinct values OR a 2-valued
681+
# numeric outside `{0, 1}` (see `_classify_treatment`). An all-zero
682+
# numeric column is classified as `binary_absorbing` and never
683+
# reaches this branch, so `nonzero` is guaranteed non-empty.
675684
nonzero = col[col != 0]
676-
if len(nonzero) > 0:
677-
dose_min = float(nonzero.min())
678-
dose_max = float(nonzero.max())
679-
dose_mean = float(nonzero.mean())
680-
else:
681-
dose_min = float("nan")
682-
dose_max = float("nan")
683-
dose_mean = float("nan")
685+
dose_min = float(nonzero.min())
686+
dose_max = float(nonzero.max())
687+
dose_mean = float(nonzero.mean())
684688

685689
is_time_invariant = True
686690
for _, group in df.groupby(unit, sort=False):
687691
unit_doses = group[treatment].dropna().to_numpy()
688692
unit_nonzero = unit_doses[unit_doses != 0]
689693
if len(unit_nonzero) == 0:
690694
continue
691-
rounded = np.round(unit_nonzero.astype(float), 8)
692-
if int(np.unique(rounded).size) > 1:
695+
# Exact distinct-count on observed non-zero values, matching the
696+
# documented contract "per-unit non-zero doses have at most one
697+
# distinct value." No tolerance is applied: continuous-DiD
698+
# eligibility is gated downstream by `ContinuousDiD.fit()`,
699+
# which itself uses exact equality on the dose column.
700+
if int(np.unique(unit_nonzero).size) > 1:
693701
is_time_invariant = False
694702
break
695703

tests/test_profile_panel.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,6 +1075,37 @@ def test_treatment_dose_continuous_time_varying_within_unit():
10751075
assert dose.is_time_invariant is False
10761076

10771077

1078+
def test_treatment_dose_distinguishes_doses_at_high_precision():
1079+
"""`is_time_invariant` uses EXACT distinct-count on observed non-zero
1080+
doses; a unit with two non-equal doses must be flagged as
1081+
time-varying even when the values differ only at sub-1e-8 precision.
1082+
Guards against an earlier implementation that rounded to 8 decimals
1083+
before comparing, silently treating tiny-but-real dose variation as
1084+
time-invariant. Required by the documented contract "per-unit
1085+
non-zero doses have at most one distinct value."""
1086+
rows = []
1087+
for u in range(1, 21):
1088+
for t in range(4):
1089+
if u <= 5:
1090+
dose = 0.0
1091+
elif u == 10:
1092+
# Unit 10 has two distinct nonzero doses separated by
1093+
# 1e-9 - smaller than the previous 1e-8 rounding window.
1094+
dose = 2.5 if t < 2 else 2.5 + 1e-9
1095+
else:
1096+
dose = 2.5
1097+
rows.append({"u": u, "t": t, "tr": dose, "y": 0.0})
1098+
df = pd.DataFrame(rows)
1099+
profile = profile_panel(df, unit="u", time="t", treatment="tr", outcome="y")
1100+
dose = profile.treatment_dose
1101+
assert dose is not None
1102+
assert dose.is_time_invariant is False, (
1103+
"Expected is_time_invariant=False for a unit with non-equal "
1104+
"non-zero doses, even at sub-1e-8 precision; the field's "
1105+
"documented contract is exact distinct-count on observed values."
1106+
)
1107+
1108+
10781109
def test_treatment_dose_continuous_no_zero_dose():
10791110
"""If every unit has a strictly positive dose throughout, has_zero_dose
10801111
must be False — flagging the absence of zero-dose controls required by

0 commit comments

Comments
 (0)