Skip to content

Commit 98b8201

Browse files
igerberclaude
andcommitted
Address PR #346 CI review round 10: P1 sklearn-compatible get/set_params
**P1 (Code Quality): sklearn parameter contract** The class docstring advertises `sklearn.base.clone(est)` compatibility, but the actual methods did not match sklearn's `BaseEstimator.get_params(deep=True)` / `set_params(**params)` surface: - `get_params()` did not accept a `deep` keyword. sklearn's `clone` calls `get_params(deep=False)`, so any integration would have failed with a TypeError on the missing kwarg. - `set_params()` validated keys with `hasattr(self, key)`. That would silently accept non-constructor attribute names like `fit`, and a typo or malicious kwargs dict could overwrite estimator methods. Fix: - `get_params(self, deep: bool = True)` matches sklearn's signature. `deep` is accepted for compat; this estimator has no nested sub-estimators, so `deep=True` and `deep=False` return the same dict. `del deep` documents the no-op explicitly and silences unused-arg linters. - `set_params(**params)` now restricts to keys from `get_params()`. Non-constructor attribute names (including method names like `fit` and dunder/private attrs) raise `ValueError`. **Tests (+4 regression):** - test_set_params_rejects_method_names: `set_params(fit=...)` raises and `est.fit` stays callable. - test_set_params_rejects_private_attrs: `_internal=42` raises. - test_get_params_accepts_deep_keyword: `deep=True`, `deep=False`, and no-arg all return the same dict. - test_sklearn_clone_round_trip_if_available: `sklearn.base.clone` round-trips the estimator; gated on `pytest.importorskip("sklearn")` so it skips cleanly when sklearn is not in the test matrix. Targeted regression: 155 HAD tests (+ 1 skipped) + 534 total (+1 skipped) across Phase 1 and adjacent surfaces, all green. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a3553ed commit 98b8201

2 files changed

Lines changed: 64 additions & 8 deletions

File tree

diff_diff/had.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -994,13 +994,23 @@ def _validate_constructor_args(self) -> None:
994994
f"or None."
995995
)
996996

997-
def get_params(self) -> Dict[str, Any]:
997+
def get_params(self, deep: bool = True) -> Dict[str, Any]:
998998
"""Return the raw constructor parameters (sklearn-compatible).
999999
1000-
Preserves the user's original inputs - in particular, ``design``
1001-
returns ``"auto"`` when the user set it to ``"auto"`` (even after
1002-
fit), so ``sklearn.base.clone(est)`` round-trips exactly.
1000+
Matches the :meth:`sklearn.base.BaseEstimator.get_params`
1001+
signature. Preserves the user's original inputs - in particular,
1002+
``design`` returns ``"auto"`` when the user set it to ``"auto"``
1003+
(even after fit), so ``sklearn.base.clone(est)`` round-trips
1004+
exactly.
1005+
1006+
Parameters
1007+
----------
1008+
deep : bool, default=True
1009+
Accepted for sklearn-contract compatibility. This estimator
1010+
has no nested sub-estimator parameters, so ``deep=False``
1011+
and ``deep=True`` return the same dict.
10031012
"""
1013+
del deep # accepted for compat; this estimator has no nested params
10041014
return {
10051015
"design": self.design,
10061016
"d_lower": self.d_lower,
@@ -1012,12 +1022,18 @@ def get_params(self) -> Dict[str, Any]:
10121022
}
10131023

10141024
def set_params(self, **params: Any) -> "HeterogeneousAdoptionDiD":
1015-
"""Set estimator parameters and return self (sklearn-compatible)."""
1025+
"""Set estimator parameters and return self (sklearn-compatible).
1026+
1027+
Only keys returned by :meth:`get_params` are accepted. Passing
1028+
any other attribute name (including method names like ``fit``)
1029+
raises ``ValueError`` so the estimator cannot be silently
1030+
corrupted by a mistyped or attacker-supplied key.
1031+
"""
1032+
valid_keys = set(self.get_params().keys())
10161033
for key, value in params.items():
1017-
if not hasattr(self, key):
1034+
if key not in valid_keys:
10181035
raise ValueError(
1019-
f"Invalid parameter: {key}. Valid parameters: "
1020-
f"{list(self.get_params().keys())}."
1036+
f"Invalid parameter: {key!r}. Valid parameters: " f"{sorted(valid_keys)}."
10211037
)
10221038
setattr(self, key, value)
10231039
self._validate_constructor_args()

tests/test_had.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,46 @@ def test_set_params_invalid_key_raises(self):
849849
with pytest.raises(ValueError, match="Invalid parameter"):
850850
est.set_params(not_a_param=True)
851851

852+
def test_set_params_rejects_method_names(self):
853+
"""Review P1 round 10: set_params must restrict to constructor keys,
854+
not any hasattr-able name. Method names like 'fit' must raise,
855+
else they would silently overwrite the method.
856+
"""
857+
est = HeterogeneousAdoptionDiD()
858+
with pytest.raises(ValueError, match="Invalid parameter"):
859+
est.set_params(fit="not_a_method")
860+
# sanity: fit is still callable on the class
861+
assert callable(est.fit)
862+
863+
def test_set_params_rejects_private_attrs(self):
864+
"""Internal-looking attribute names must also raise."""
865+
est = HeterogeneousAdoptionDiD()
866+
with pytest.raises(ValueError, match="Invalid parameter"):
867+
est.set_params(_internal=42)
868+
869+
def test_get_params_accepts_deep_keyword(self):
870+
"""Review P1 round 10: get_params must match sklearn's signature.
871+
872+
sklearn.base.BaseEstimator.get_params(deep=True). This estimator
873+
has no nested sub-estimators, so deep=True and deep=False return
874+
the same dict, but the keyword must be accepted.
875+
"""
876+
est = HeterogeneousAdoptionDiD(design="continuous_at_zero", alpha=0.1)
877+
params_default = est.get_params()
878+
params_deep_true = est.get_params(deep=True)
879+
params_deep_false = est.get_params(deep=False)
880+
assert params_default == params_deep_true == params_deep_false
881+
882+
def test_sklearn_clone_round_trip_if_available(self):
883+
"""If sklearn is installed, sklearn.base.clone round-trips the estimator."""
884+
sklearn_base = pytest.importorskip("sklearn.base")
885+
est = HeterogeneousAdoptionDiD(design="auto", alpha=0.1, kernel="triangular")
886+
cloned = sklearn_base.clone(est)
887+
assert cloned.get_params() == est.get_params()
888+
assert cloned is not est
889+
# clone produces a fresh instance of the same class.
890+
assert type(cloned) is type(est)
891+
852892
def test_set_params_invalid_design_raises(self):
853893
est = HeterogeneousAdoptionDiD()
854894
with pytest.raises(ValueError, match="design"):

0 commit comments

Comments
 (0)