Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aeon/base/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def test_clone_function():
"capability:missing_values": True,
"capability:multithreading": False,
"capability:multivariate": True,
"capability:predict_proba": True,
"capability:train_estimate": False,
"capability:unequal_length": True,
"capability:univariate": True,
Expand Down
1 change: 1 addition & 0 deletions aeon/classification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class BaseClassifier(ClassifierMixin, BaseCollectionEstimator):
"fit_is_empty": False,
"capability:train_estimate": False,
"capability:contractable": False,
"capability:predict_proba": True,
}

@abstractmethod
Expand Down
7 changes: 7 additions & 0 deletions aeon/classification/compose/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ def __init__(self, transformers, classifier, random_state=None):
transformers=transformers, _estimator=classifier, random_state=random_state
)

if hasattr(classifier, "get_tag"):
can_proba = classifier.get_tag("capability:predict_proba", True)
else:
can_proba = hasattr(classifier, "predict_proba")

self.set_tags(**{"capability:predict_proba": can_proba})

@classmethod
def _get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the estimator.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ def __init__(

super().__init__()

if n_neighbors == 1:
self.set_tags(**{"capability:predict_proba": False})

def _fit(self, X, y):
"""
Fit the model using ``X`` as training data and ``y`` as target values.
Expand Down
31 changes: 31 additions & 0 deletions aeon/classification/tests/test_predict_proba_tag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Tests for the capability:predict_proba tag."""

from aeon.classification.compose import ClassifierPipeline
from aeon.classification.distance_based import KNeighborsTimeSeriesClassifier
from aeon.transformations.collection.unequal_length import Truncator


def test_predict_proba_tag_knn():
"""Test that KNN correctly sets the tag based on n_neighbors."""
knn1 = KNeighborsTimeSeriesClassifier(n_neighbors=1)
assert not knn1.get_tag("capability:predict_proba")

knn5 = KNeighborsTimeSeriesClassifier(n_neighbors=5)
assert knn5.get_tag("capability:predict_proba")


def test_predict_proba_tag_pipeline():
"""Test that ClassifierPipeline correctly inherits the tag."""
transformer = Truncator(truncated_length=5)

pipe1 = ClassifierPipeline(
transformers=[transformer],
classifier=KNeighborsTimeSeriesClassifier(n_neighbors=1),
)
assert not pipe1.get_tag("capability:predict_proba")

pipe5 = ClassifierPipeline(
transformers=[transformer],
classifier=KNeighborsTimeSeriesClassifier(n_neighbors=5),
)
assert pipe5.get_tag("capability:predict_proba")
5 changes: 5 additions & 0 deletions aeon/utils/tags/_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ class : identifier for the base class of objects this tag applies to
"type": "bool",
"description": "Can the estimator limiting max fit time?",
},
"capability:predict_proba": {
"class": "classifier",
"type": "bool",
"description": "Does the estimator support granular predict_proba estimates?",
},
"capability:exogenous": {
"class": ["forecaster"],
"type": "bool",
Expand Down
1 change: 1 addition & 0 deletions docs/changelogs/v1.3.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ September 2025
- [ENH] KNN n_jobs and updated kneighbours method ({pr}`2578`) {user}`chrisholder`
- [ENH] Refactor signature code ({pr}`2943`) {user}`TonyBagnall`
- [ENH] Change seed to random_state ({pr}`3031`) {user}`TonyBagnall`
- [ENH] Add capability:predict_proba tag to BaseClassifier ({pr}`3127`) {user}`Nithurshen`

## Clustering

Expand Down