Skip to content

Commit 8dd0e41

Browse files
igerber and claude committed
Address PR #401 R5 review (1 P2 transactional set_params)
R5 verdict was "Looks good" with one P2 code-quality regression flagged. The previous round's `test_set_params_partial_mutation_recoverable` documented inconsistent intermediate state as expected behavior: on a mutex violation, `set_params()` mutated `self` before validation, so a failed call left both `by_path` and `paths_of_interest` populated. `fit()` did not re-check the mutex and `_enumerate_treatment_paths()` silently preferred `paths_of_interest` when both were set, so subsequent fits ran with stale selector state unless the caller manually repaired the object. Fix: make `set_params()` transactional. Snapshot the pre-call values for the keys being set, attempt mutation + validation, and roll back to the snapshot on any exception before re-raising. This mirrors sklearn-style atomic semantics. The validation block was extracted into a private `_validate_invariants()` helper so `__init__` and `set_params` can share the rules without code duplication risking drift. Replaces `test_set_params_partial_mutation_recoverable` (which asserted the buggy partial-mutation behavior) with `test_set_params_failed_validation_is_transactional`, asserting that: 1. After a failed `set_params()`, `get_params()` returns a dict equal to the pre-call snapshot. 2. Subsequent valid `set_params()` succeeds against the rolled-back state. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 053d2c2 commit 8dd0e41

2 files changed

Lines changed: 38 additions & 13 deletions

File tree

diff_diff/chaisemartin_dhaultfoeuille.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -779,14 +779,30 @@ def set_params(self, **params: Any) -> "ChaisemartinDHaultfoeuille":
779779
"""
780780
Set estimator parameters (sklearn-compatible).
781781
782-
Re-runs the same validation rules as ``__init__`` so invalid
783-
parameter combinations cannot be introduced after construction.
782+
**Transactional**: validation runs after the candidate mutations,
783+
and if any rule fails the estimator state is rolled back to its
784+
pre-call values before the exception is re-raised. Callers can
785+
therefore retry with corrected params on the same instance
786+
without repairing inconsistent intermediate state.
784787
"""
785-
for key, value in params.items():
788+
# Snapshot current values for the keys we are about to set so
789+
# we can roll back on validation failure (transactional semantics).
790+
for key in params:
786791
if not hasattr(self, key):
787792
raise ValueError(f"Unknown parameter: {key}")
788-
setattr(self, key, value)
793+
snapshot = {key: getattr(self, key) for key in params}
794+
try:
795+
for key, value in params.items():
796+
setattr(self, key, value)
797+
self._validate_invariants()
798+
except Exception:
799+
for key, value in snapshot.items():
800+
setattr(self, key, value)
801+
raise
802+
return self
789803

804+
def _validate_invariants(self) -> None:
805+
"""Run the post-mutation validation rules. Mirrors `__init__`."""
790806
# Re-run __init__ validation rules so the post-set state is valid.
791807
if self.rank_deficient_action not in ("warn", "error", "silent"):
792808
raise ValueError(
@@ -834,7 +850,6 @@ def set_params(self, **params: Any) -> "ChaisemartinDHaultfoeuille":
834850
f"clustering is reserved for a future phase. See REGISTRY.md "
835851
f"ChaisemartinDHaultfoeuille section for the full contract."
836852
)
837-
return self
838853

839854
# ------------------------------------------------------------------
840855
# fit

tests/test_chaisemartin_dhaultfoeuille.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8519,18 +8519,28 @@ def test_set_params_re_validates_mutex_poi_added(self):
85198519
with pytest.raises(ValueError, match="mutually exclusive"):
85208520
est.set_params(paths_of_interest=[(0, 1, 1, 1)])
85218521

8522-
def test_set_params_partial_mutation_recoverable(self):
8523-
"""After mutex set_params raises mid-mutation, recovery is possible."""
8522+
def test_set_params_failed_validation_is_transactional(self):
8523+
"""A failed `set_params()` must leave estimator state unchanged
8524+
(regression for R5 P2 finding: prior implementation mutated
8525+
before validation, leaving both selectors populated when the
8526+
mutex check raised, which `fit()` then silently consumed)."""
85248527
est = ChaisemartinDHaultfoeuille(paths_of_interest=[(0, 1, 1, 1)])
8525-
# set_params mutates self before validation; mutex raise leaves both set
8528+
# Capture pre-call state.
8529+
before = est.get_params()
8530+
# Mutex violation: by_path AND paths_of_interest both non-None.
85268531
with pytest.raises(ValueError, match="mutually exclusive"):
85278532
est.set_params(by_path=2)
8528-
# User can recover by clearing one of them.
8529-
est.set_params(by_path=None)
8530-
# Instance now valid; subsequent calls don't raise.
8533+
# Post-failure state is rolled back to pre-call.
8534+
after = est.get_params()
8535+
assert after == before, (
8536+
f"set_params() rollback failed: by_path={after['by_path']}, "
8537+
f"paths_of_interest={after['paths_of_interest']}"
8538+
)
8539+
# Subsequent valid set_params() succeeds against rolled-back state.
8540+
est.set_params(by_path=2, paths_of_interest=None)
85318541
params = est.get_params()
8532-
assert params["by_path"] is None
8533-
assert params["paths_of_interest"] == [(0, 1, 1, 1)]
8542+
assert params["by_path"] == 2
8543+
assert params["paths_of_interest"] is None
85348544

85358545
def test_get_params_includes_paths_of_interest(self):
85368546
est = ChaisemartinDHaultfoeuille(

0 commit comments

Comments
 (0)