Commit 735e6ec
[Validate] Pass Evaluation Function arguments with EvaluationCriteria (#229)
* Pass eval_func_arguments to backend with EvaluationCriteria
* Add better error message for scenario_test misconfiguration and arguments to all public functions
* Update defaults to match metrics
* Address @phil-scale comments!
* Add examples to configuration functions and clear up class naming
* Fix rebase errors
* Another rebasing error bites the dust
* 🤦‍♂️
1 parent c9f1f59 commit 735e6ec

6 files changed: +238 −39 lines changed
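
Taken together, the diffs below change the Validate flow in one visible way: an eval function is now configured with keyword arguments, and those arguments travel to the backend inside the resulting EvaluationCriterion. A minimal usage sketch, assuming the create_scenario_test signature shown in nucleus/validate/client.py below (API key and slice ID are placeholders):

    import nucleus

    client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")

    # Configuring an eval function now records its keyword arguments; they are
    # forwarded to the backend as EvaluationCriterion.eval_func_arguments.
    bbox_iou_config = client.validate.eval_functions.bbox_iou

    scenario_test = client.validate.create_scenario_test(
        name="Example test",
        slice_id="slc_<your_slice>",  # placeholder slice ID
        evaluation_functions=[bbox_iou_config(confidence_threshold=0.8)],
    )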

nucleus/metrics/categorization_metrics.py

Lines changed: 2 additions & 1 deletion
@@ -143,7 +143,8 @@ def __init__(
     ):
         """
         Args:
-            confidence_threshold: minimum confidence threshold for predictions to be taken into account for evaluation. Must be in [0, 1]. Default 0.0
+            confidence_threshold: minimum confidence threshold for predictions to be taken into account for evaluation.
+                Must be in [0, 1]. Default 0.0
             f1_method: {'micro', 'macro', 'samples','weighted', 'binary'}, \
                 default='macro'
             This parameter is required for multiclass/multilabel targets.
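
For context, the parameter re-documented here is consumed by the metric constructor in this same file. A small sketch, assuming the enclosing class is nucleus.metrics.CategorizationF1:

    from nucleus.metrics import CategorizationF1

    # confidence_threshold must be in [0, 1]; predictions below it are ignored
    # during evaluation. f1_method selects the averaging mode ('macro' default).
    metric = CategorizationF1(confidence_threshold=0.5, f1_method="macro")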

nucleus/validate/client.py

Lines changed: 6 additions & 6 deletions
@@ -7,10 +7,8 @@
 from .data_transfer_objects.eval_function import GetEvalFunctions
 from .data_transfer_objects.scenario_test import CreateScenarioTestRequest
 from .errors import CreateScenarioTestError
-from .eval_functions.available_eval_functions import (
-    AvailableEvalFunctions,
-    EvalFunction,
-)
+from .eval_functions.available_eval_functions import AvailableEvalFunctions
+from .eval_functions.base_eval_function import EvalFunctionConfig
 from .scenario_test import ScenarioTest
 
 SUCCESS_KEY = "success"
@@ -36,7 +34,8 @@ def eval_functions(self) -> AvailableEvalFunctions:
             import nucleus
             client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
 
-            scenario_test_criterion = client.validate.eval_functions.bbox_iou() > 0.5  # Creates an EvaluationCriterion by comparison
+            # Creates an EvaluationCriterion by using a comparison op
+            scenario_test_criterion = client.validate.eval_functions.bbox_iou() > 0.5
 
         Returns:
             :class:`AvailableEvalFunctions`: A container for all the available eval functions
@@ -51,7 +50,7 @@ def create_scenario_test(
         self,
         name: str,
         slice_id: str,
-        evaluation_functions: List[EvalFunction],
+        evaluation_functions: List[EvalFunctionConfig],
     ) -> ScenarioTest:
         """Creates a new Scenario Test from an existing Nucleus :class:`Slice`:. ::
 
@@ -78,6 +77,7 @@ def create_scenario_test(
                 "Must pass an evaluation_function to the scenario test! I.e. "
                 "evaluation_functions=[client.validate.eval_functions.bbox_iou()]"
             )
+
         response = self.connection.post(
             CreateScenarioTestRequest(
                 name=name,
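
The sharpened error message above fires when create_scenario_test is called without any evaluation functions. A sketch of that failure mode and its fix, assuming the CreateScenarioTestError imported at the top of this file is the exception raised (client as in the earlier sketch):

    from nucleus.validate.errors import CreateScenarioTestError

    try:
        # The misconfiguration the new message targets: an empty list.
        client.validate.create_scenario_test(
            name="Example test",
            slice_id="slc_<your_slice>",
            evaluation_functions=[],
        )
    except CreateScenarioTestError:
        # Fix, per the message: pass at least one configured eval function,
        # e.g. evaluation_functions=[client.validate.eval_functions.bbox_iou()]
        raise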

nucleus/validate/data_transfer_objects/eval_function.py

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 from pydantic import validator
 
@@ -50,12 +50,14 @@ class EvaluationCriterion(ImmutableModel):
         eval_function_id (str): ID of evaluation function
         threshold_comparison (:class:`ThresholdComparison`): comparator for evaluation. i.e. threshold=0.5 and threshold_comparator > implies that a test only passes if score > 0.5.
         threshold (float): numerical threshold that together with threshold comparison, defines success criteria for test evaluation.
+        eval_func_arguments: Arguments to pass to the eval function constructor
     """
 
     # TODO: Having only eval_function_id hurts readability -> Add function name
     eval_function_id: str
     threshold_comparison: ThresholdComparison
     threshold: float
+    eval_func_arguments: Dict[str, Any]
 
     @validator("eval_function_id")
     def valid_eval_function_id(cls, v):  # pylint: disable=no-self-argument
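
With the new field in place, an EvaluationCriterion can also be built directly. A sketch based on the fields shown above; the ThresholdComparison import path and member name are assumptions, since neither appears in this diff:

    from nucleus.validate.data_transfer_objects.eval_function import (
        EvaluationCriterion,
        ThresholdComparison,  # assumed to live in this module
    )

    criterion = EvaluationCriterion(
        eval_function_id="ef_<id>",  # placeholder eval function ID
        threshold_comparison=ThresholdComparison.GREATER_THAN,  # assumed member name
        threshold=0.5,
        # New in this commit: constructor arguments for the eval function.
        eval_func_arguments={"confidence_threshold": 0.8},
    )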

nucleus/validate/eval_functions/available_eval_functions.py

Lines changed: 209 additions & 28 deletions
@@ -1,54 +1,234 @@
 import itertools
-from typing import Callable, Dict, List, Type, Union
+from typing import Callable, Dict, List, Optional, Union
 
 from nucleus.logger import logger
-from nucleus.validate.eval_functions.base_eval_function import BaseEvalFunction
+from nucleus.validate.eval_functions.base_eval_function import (
+    EvalFunctionConfig,
+)
 
 from ..data_transfer_objects.eval_function import EvalFunctionEntry
 from ..errors import EvalFunctionNotAvailableError
 
 MEAN_AVG_PRECISION_NAME = "mean_average_precision_boxes"
 
 
-class BoundingBoxIOU(BaseEvalFunction):
+class PolygonIOUConfig(EvalFunctionConfig):
+    def __call__(
+        self,
+        enforce_label_match: bool = False,
+        iou_threshold: float = 0.0,
+        confidence_threshold: float = 0.0,
+        **kwargs,
+    ):
+        """Configures a call to :class:`PolygonIOU` object.
+        ::
+
+            import nucleus
+
+            client = nucleus.NucleusClient(YOUR_SCALE_API_KEY)
+            bbox_iou: BoundingBoxIOU = client.validate.eval_functions.bbox_iou
+            slice_id = "slc_<your_slice>"
+            scenario_test = client.validate.create_scenario_test(
+                "Example test",
+                slice_id=slice_id,
+                evaluation_criteria=[bbox_iou(confidence_threshold=0.8) > 0.5]
+            )
+
+        Args:
+            enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False
+            iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0
+            confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0
+        """
+        return super().__call__(
+            enforce_label_match=enforce_label_match,
+            iou_threshold=iou_threshold,
+            confidence_threshold=confidence_threshold,
+            **kwargs,
+        )
+
     @classmethod
     def expected_name(cls) -> str:
         return "bbox_iou"
 
 
-class BoundingBoxMeanAveragePrecision(BaseEvalFunction):
+class PolygonMAPConfig(EvalFunctionConfig):
+    def __call__(
+        self,
+        iou_threshold: float = 0.5,
+        **kwargs,
+    ):
+        """Configures a call to :class:`PolygonMAP` object.
+        ::
+
+            import nucleus
+
+            client = nucleus.NucleusClient(YOUR_SCALE_API_KEY)
+            bbox_map: BoundingBoxMeanAveragePrecision= client.validate.eval_functions.bbox_map
+            slice_id = "slc_<your_slice>"
+            scenario_test = client.validate.create_scenario_test(
+                "Example test",
+                slice_id=slice_id,
+                evaluation_criteria=[bbox_map(iou_threshold=0.6) > 0.8]
+            )
+
+        Args:
+            iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0
+        """
+        return super().__call__(
+            iou_threshold=iou_threshold,
+            **kwargs,
+        )
+
     @classmethod
     def expected_name(cls) -> str:
         return "bbox_map"
 
 
-class BoundingBoxRecall(BaseEvalFunction):
+class PolygonRecallConfig(EvalFunctionConfig):
+    def __call__(
+        self,
+        enforce_label_match: bool = False,
+        iou_threshold: float = 0.5,
+        confidence_threshold: float = 0.0,
+        **kwargs,
+    ):
+        """Configures a call to :class:`PolygonRecall` object.
+        ::
+
+            import nucleus
+
+            client = nucleus.NucleusClient(YOUR_SCALE_API_KEY)
+            bbox_recall: BoundingBoxMeanAveragePrecision= client.validate.eval_functions.bbox_recall
+            slice_id = "slc_<your_slice>"
+            scenario_test = client.validate.create_scenario_test(
+                "Example test",
+                slice_id=slice_id,
+                evaluation_criteria=[bbox_recall(iou_threshold=0.6, confidence_threshold=0.4) > 0.9]
+            )
+
+        Args:
+            enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False
+            iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0
+            confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0
+        """
+        return super().__call__(
+            enforce_label_match=enforce_label_match,
+            iou_threshold=iou_threshold,
+            confidence_threshold=confidence_threshold,
+            **kwargs,
+        )
+
     @classmethod
     def expected_name(cls) -> str:
         return "bbox_recall"
 
 
-class BoundingBoxPrecision(BaseEvalFunction):
+class PolygonPrecisionConfig(EvalFunctionConfig):
+    def __call__(
+        self,
+        enforce_label_match: bool = False,
+        iou_threshold: float = 0.5,
+        confidence_threshold: float = 0.0,
+        **kwargs,
+    ):
+        """Configures a call to :class:`PolygonPrecision` object.
+        ::
+
+            import nucleus
+
+            client = nucleus.NucleusClient(YOUR_SCALE_API_KEY)
+            bbox_precision: BoundingBoxMeanAveragePrecision= client.validate.eval_functions.bbox_precision
+            slice_id = "slc_<your_slice>"
+            scenario_test = client.validate.create_scenario_test(
+                "Example test",
+                slice_id=slice_id,
+                evaluation_criteria=[bbox_precision(iou_threshold=0.6, confidence_threshold=0.4) > 0.9]
+            )
+
+        Args:
+            enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False
+            iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0
+            confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0
+        """
+        return super().__call__(
+            enforce_label_match=enforce_label_match,
+            iou_threshold=iou_threshold,
+            confidence_threshold=confidence_threshold,
+            **kwargs,
+        )
+
     @classmethod
     def expected_name(cls) -> str:
         return "bbox_precision"
 
 
-class CategorizationF1(BaseEvalFunction):
+class CategorizationF1Config(EvalFunctionConfig):
+    def __call__(
+        self,
+        confidence_threshold: Optional[float] = None,
+        f1_method: Optional[str] = None,
+        **kwargs,
+    ):
+        """Configure an evaluation of :class:`CategorizationF1`.
+        ::
+
+            import nucleus
+
+            client = nucleus.NucleusClient(YOUR_SCALE_API_KEY)
+            cat_f1: CategorizationF1 = client.validate.eval_functions.cat_f1
+            slice_id = "slc_<your_slice>"
+            scenario_test = client.validate.create_scenario_test(
+                "Example test",
+                slice_id=slice_id,
+                evaluation_criteria=[cat_f1(confidence_threshold=0.6, f1_method="weighted") > 0.7]
+            )
+
+        Args:
+            confidence_threshold: minimum confidence threshold for predictions to be taken into account for evaluation.
+                Must be in [0, 1]. Default 0.0
+            f1_method: {'micro', 'macro', 'samples','weighted', 'binary'}, \
+                default='macro'
+                This parameter is required for multiclass/multilabel targets.
+                If ``None``, the scores for each class are returned. Otherwise, this
+                determines the type of averaging performed on the data:
+
+                ``'binary'``:
+                    Only report results for the class specified by ``pos_label``.
+                    This is applicable only if targets (``y_{true,pred}``) are binary.
+                ``'micro'``:
+                    Calculate metrics globally by counting the total true positives,
+                    false negatives and false positives.
+                ``'macro'``:
+                    Calculate metrics for each label, and find their unweighted
+                    mean. This does not take label imbalance into account.
+                ``'weighted'``:
+                    Calculate metrics for each label, and find their average weighted
+                    by support (the number of true instances for each label). This
+                    alters 'macro' to account for label imbalance; it can result in an
+                    F-score that is not between precision and recall.
+                ``'samples'``:
+                    Calculate metrics for each instance, and find their average (only
+                    meaningful for multilabel classification where this differs from
+                    :func:`accuracy_score`).
+        """
+        return super().__call__(
+            confidence_threshold=confidence_threshold, f1_method=f1_method
+        )
+
     @classmethod
     def expected_name(cls) -> str:
         return "cat_f1"
 
 
-class CustomEvalFunction(BaseEvalFunction):
+class CustomEvalFunction(EvalFunctionConfig):
     @classmethod
     def expected_name(cls) -> str:
         raise NotImplementedError(
             "Custm evaluation functions are coming soon"
         )  # Placeholder: See super().eval_func_entry for actual name
 
 
-class StandardEvalFunction(BaseEvalFunction):
+class StandardEvalFunction(EvalFunctionConfig):
     """Class for standard Model CI eval functions that have not been added as attributes on
     AvailableEvalFunctions yet.
     """
@@ -65,7 +245,7 @@ def expected_name(cls) -> str:
         return "public_function"  # Placeholder: See super().eval_func_entry for actual name
 
 
-class EvalFunctionNotAvailable(BaseEvalFunction):
+class EvalFunctionNotAvailable(EvalFunctionConfig):
     def __init__(
         self, not_available_name: str
     ):  # pylint: disable=super-init-not-called
@@ -89,13 +269,14 @@ def expected_name(cls) -> str:
 
 
 EvalFunction = Union[
-    Type[BoundingBoxIOU],
-    Type[BoundingBoxMeanAveragePrecision],
-    Type[BoundingBoxPrecision],
-    Type[BoundingBoxRecall],
-    Type[CustomEvalFunction],
-    Type[EvalFunctionNotAvailable],
-    Type[StandardEvalFunction],
+    PolygonIOUConfig,
+    PolygonMAPConfig,
+    PolygonPrecisionConfig,
+    PolygonRecallConfig,
+    CategorizationF1Config,
+    CustomEvalFunction,
+    EvalFunctionNotAvailable,
+    StandardEvalFunction,
 ]
 
 
@@ -124,24 +305,24 @@ def __init__(self, available_functions: List[EvalFunctionEntry]):
             f.name: f for f in available_functions if f.is_public
         }
         # NOTE: Public are assigned
-        self._public_to_function: Dict[str, BaseEvalFunction] = {}
+        self._public_to_function: Dict[str, EvalFunctionConfig] = {}
         self._custom_to_function: Dict[str, CustomEvalFunction] = {
             f.name: CustomEvalFunction(f)
             for f in available_functions
             if not f.is_public
         }
-        self.bbox_iou = self._assign_eval_function_if_defined(BoundingBoxIOU)  # type: ignore
-        self.bbox_precision = self._assign_eval_function_if_defined(
-            BoundingBoxPrecision  # type: ignore
+        self.bbox_iou: PolygonIOUConfig = self._assign_eval_function_if_defined(PolygonIOUConfig)  # type: ignore
+        self.bbox_precision: PolygonPrecisionConfig = self._assign_eval_function_if_defined(
+            PolygonPrecisionConfig  # type: ignore
         )
-        self.bbox_recall = self._assign_eval_function_if_defined(
-            BoundingBoxRecall  # type: ignore
+        self.bbox_recall: PolygonRecallConfig = self._assign_eval_function_if_defined(
+            PolygonRecallConfig  # type: ignore
         )
-        self.bbox_map = self._assign_eval_function_if_defined(
-            BoundingBoxMeanAveragePrecision  # type: ignore
+        self.bbox_map: PolygonMAPConfig = self._assign_eval_function_if_defined(
+            PolygonMAPConfig  # type: ignore
        )
-        self.cat_f1 = self._assign_eval_function_if_defined(
-            CategorizationF1  # type: ignore
+        self.cat_f1: CategorizationF1Config = self._assign_eval_function_if_defined(
+            CategorizationF1Config  # type: ignore
         )
 
         # Add public entries that have not been implemented as an attribute on this class
@@ -163,7 +344,7 @@ def __repr__(self):
         )
 
     @property
-    def public_functions(self) -> Dict[str, BaseEvalFunction]:
+    def public_functions(self) -> Dict[str, EvalFunctionConfig]:
         """Standard functions provided by Model CI.
 
         Notes:
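
Every renamed class above follows the same contract: expected_name binds the config to a backend eval function, and __call__ captures keyword arguments by delegating to the base class, which later ships them as eval_func_arguments. A sketch of that pattern with a hypothetical subclass (EvalFunctionConfig's internals are not part of this diff, so only the delegation visible above is assumed):

    from nucleus.validate.eval_functions.base_eval_function import (
        EvalFunctionConfig,
    )


    class ExampleIOUConfig(EvalFunctionConfig):  # hypothetical, for illustration
        def __call__(self, iou_threshold: float = 0.5, **kwargs):
            # Captured kwargs end up in EvaluationCriterion.eval_func_arguments.
            return super().__call__(iou_threshold=iou_threshold, **kwargs)

        @classmethod
        def expected_name(cls) -> str:
            # Must match the backend eval function's public name.
            return "example_iou"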
