diff --git a/aeon/base/tests/test_base.py b/aeon/base/tests/test_base.py index fcb4ac1907..14f4949ba7 100644 --- a/aeon/base/tests/test_base.py +++ b/aeon/base/tests/test_base.py @@ -137,6 +137,7 @@ def test_clone_function(): "capability:missing_values": True, "capability:multithreading": False, "capability:multivariate": True, + "capability:predict_proba": True, "capability:train_estimate": False, "capability:unequal_length": True, "capability:univariate": True, diff --git a/aeon/classification/base.py b/aeon/classification/base.py index f0609c9c72..26d11bf5ee 100644 --- a/aeon/classification/base.py +++ b/aeon/classification/base.py @@ -58,6 +58,7 @@ class BaseClassifier(ClassifierMixin, BaseCollectionEstimator): "fit_is_empty": False, "capability:train_estimate": False, "capability:contractable": False, + "capability:predict_proba": True, } @abstractmethod diff --git a/aeon/classification/compose/_pipeline.py b/aeon/classification/compose/_pipeline.py index ce176b810c..9f8c9ebc88 100644 --- a/aeon/classification/compose/_pipeline.py +++ b/aeon/classification/compose/_pipeline.py @@ -86,6 +86,13 @@ def __init__(self, transformers, classifier, random_state=None): transformers=transformers, _estimator=classifier, random_state=random_state ) + if hasattr(classifier, "get_tag"): + can_proba = classifier.get_tag("capability:predict_proba", True) + else: + can_proba = hasattr(classifier, "predict_proba") + + self.set_tags(**{"capability:predict_proba": can_proba}) + @classmethod def _get_test_params(cls, parameter_set="default"): """Return testing parameter settings for the estimator. diff --git a/aeon/classification/distance_based/_time_series_neighbors.py b/aeon/classification/distance_based/_time_series_neighbors.py index b92f3ebbe1..ea2743bc01 100644 --- a/aeon/classification/distance_based/_time_series_neighbors.py +++ b/aeon/classification/distance_based/_time_series_neighbors.py @@ -105,6 +105,9 @@ def __init__( super().__init__() + if n_neighbors == 1: + self.set_tags(**{"capability:predict_proba": False}) + def _fit(self, X, y): """ Fit the model using ``X`` as training data and ``y`` as target values. diff --git a/aeon/classification/tests/test_predict_proba_tag.py b/aeon/classification/tests/test_predict_proba_tag.py new file mode 100644 index 0000000000..37e08ea6ed --- /dev/null +++ b/aeon/classification/tests/test_predict_proba_tag.py @@ -0,0 +1,31 @@ +"""Tests for the capability:predict_proba tag.""" + +from aeon.classification.compose import ClassifierPipeline +from aeon.classification.distance_based import KNeighborsTimeSeriesClassifier +from aeon.transformations.collection.unequal_length import Truncator + + +def test_predict_proba_tag_knn(): + """Test that KNN correctly sets the tag based on n_neighbors.""" + knn1 = KNeighborsTimeSeriesClassifier(n_neighbors=1) + assert not knn1.get_tag("capability:predict_proba") + + knn5 = KNeighborsTimeSeriesClassifier(n_neighbors=5) + assert knn5.get_tag("capability:predict_proba") + + +def test_predict_proba_tag_pipeline(): + """Test that ClassifierPipeline correctly inherits the tag.""" + transformer = Truncator(truncated_length=5) + + pipe1 = ClassifierPipeline( + transformers=[transformer], + classifier=KNeighborsTimeSeriesClassifier(n_neighbors=1), + ) + assert not pipe1.get_tag("capability:predict_proba") + + pipe5 = ClassifierPipeline( + transformers=[transformer], + classifier=KNeighborsTimeSeriesClassifier(n_neighbors=5), + ) + assert pipe5.get_tag("capability:predict_proba") diff --git a/aeon/utils/tags/_tags.py b/aeon/utils/tags/_tags.py index ff9274863b..b1ced0c446 100644 --- a/aeon/utils/tags/_tags.py +++ b/aeon/utils/tags/_tags.py @@ -124,6 +124,11 @@ class : identifier for the base class of objects this tag applies to "type": "bool", "description": "Can the estimator limiting max fit time?", }, + "capability:predict_proba": { + "class": "classifier", + "type": "bool", + "description": "Does the estimator support granular predict_proba estimates?", + }, "capability:exogenous": { "class": ["forecaster"], "type": "bool", diff --git a/docs/changelogs/v1.3.md b/docs/changelogs/v1.3.md index eae8566a67..a77e1d1225 100644 --- a/docs/changelogs/v1.3.md +++ b/docs/changelogs/v1.3.md @@ -39,6 +39,7 @@ September 2025 - [ENH] KNN n_jobs and updated kneighbours method ({pr}`2578`) {user}`chrisholder` - [ENH] Refactor signature code ({pr}`2943`) {user}`TonyBagnall` - [ENH] Change seed to random_state ({pr}`3031`) {user}`TonyBagnall` +- [ENH] Add capability:predict_proba tag to BaseClassifier ({pr}`3127`) {user}`Nithurshen` ## Clustering