Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions aisp/base/core/_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Base class for parameter introspection compatible with the scikit-learn API."""

from __future__ import annotations

from inspect import signature

class Base:
"""
Expand Down Expand Up @@ -43,7 +43,7 @@ def get_params(self, deep: bool = True) -> dict: # pylint: disable=W0613
Dictionary containing the object's attributes that do not start with "_".
"""
return {
key: value
for key, value in self.__dict__.items()
if not key.startswith("_")
}
key: getattr(self, key)
for key, _ in signature(self.__init__).parameters.items()
if key != "self" and not key.startswith("_") and hasattr(self, key)
}
13 changes: 8 additions & 5 deletions aisp/base/core/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@ class TestBase:

def setup_method(self):
"""Set up a Base instance with example attributes before each test."""
self.obj = Base()
self.obj.alpha = 1
self.obj.beta = 2
class BaseFixture(Base):
"""Fixture class for Base tests."""
def __init__(self, alpha=1, beta=2):
self.alpha = alpha
self.beta = beta

self.obj = BaseFixture()

def test_set_and_get_params_basic(self):
"""Test setting parameters using set_params and retrieving them with get_params."""
Expand All @@ -28,8 +32,7 @@ def test_get_params_excludes_private(self):

params = self.obj.get_params()
assert "_private" not in params
assert "public" in params
assert params["public"] == "ok"
assert "public" not in params

def test_set_params_updates_existing(self):
"""Test that set_params updates existing attributes correctly."""
Expand Down
4 changes: 2 additions & 2 deletions aisp/csa/_ai_recognition_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__(
resource_amplified: float = 1.0,
metric: MetricType = "euclidean",
seed: Optional[int] = None,
**kwargs,
p: float = 2.0,
) -> None:
self.n_resources: float = sanitize_param(n_resources, 10, lambda x: x >= 1)
self.rate_mc_init: float = sanitize_param(
Expand All @@ -153,7 +153,7 @@ def __init__(

self.metric = sanitize_choice(metric, ["manhattan", "minkowski"], "euclidean")

self.p: np.float64 = np.float64(kwargs.get("p", 2.0))
self.p = p

self._cells_memory: Optional[Dict[str | int, list[BCell]]] = None
self._all_class_cell_vectors: Optional[List[Tuple[Any, np.ndarray]]] = None
Expand Down
6 changes: 3 additions & 3 deletions aisp/csa/tests/test_ai_recognition_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class TestAIRSBinary:
def test_fit_and_predict(self, b_airs_data):
"""Should the fit and predict methods of the AIRS class."""
X, y, seed = b_airs_data
model = AIRS(algorithm="binary-features", seed=seed)
model = AIRS(seed=seed)
model.fit(X, y, verbose=False)
predictions = model.predict(X)
assert predictions is not None
Expand All @@ -44,7 +44,7 @@ def test_fit_and_predict(self, b_airs_data):
def test_predict_raises_feature_dimension_mismatch(self, b_airs_data):
"""Should raise FeatureDimensionMismatch when prediction input has wrong dimensions."""
X, y, seed = b_airs_data
model = AIRS(algorithm="binary-features", seed=seed)
model = AIRS(seed=seed)
model.fit(X, y, verbose=False)
x_invalid = np.random.choice([True, False], size=(5, 5))

Expand All @@ -54,7 +54,7 @@ def test_predict_raises_feature_dimension_mismatch(self, b_airs_data):
def test_score_range(self, b_airs_data):
"""Score should return a value between 0 and 1."""
X, y, seed = b_airs_data
model = AIRS(algorithm="binary-features", seed=seed)
model = AIRS(seed=seed)
model.fit(X, y, verbose=False)
score = model.score(X, y)
assert isinstance(score, float)
Expand Down
Loading