1
1
import itertools
2
- from typing import Callable , Dict , List , Type , Union
2
+ from typing import Callable , Dict , List , Optional , Union
3
3
4
4
from nucleus .logger import logger
5
- from nucleus .validate .eval_functions .base_eval_function import BaseEvalFunction
5
+ from nucleus .validate .eval_functions .base_eval_function import (
6
+ EvalFunctionConfig ,
7
+ )
6
8
7
9
from ..data_transfer_objects .eval_function import EvalFunctionEntry
8
10
from ..errors import EvalFunctionNotAvailableError
9
11
10
12
# Canonical server-side name of the bounding-box mean-average-precision eval function.
MEAN_AVG_PRECISION_NAME = "mean_average_precision_boxes"
11
13
12
14
13
class PolygonIOUConfig(EvalFunctionConfig):
    """Config builder for the bbox/polygon IOU eval function."""

    def __call__(
        self,
        enforce_label_match: bool = False,
        iou_threshold: float = 0.0,
        confidence_threshold: float = 0.0,
        **kwargs,
    ):
        """Configures a call to :class:`PolygonIOU` object.
        ::

            import nucleus

            client = nucleus.NucleusClient(YOUR_SCALE_API_KEY)
            bbox_iou: PolygonIOUConfig = client.validate.eval_functions.bbox_iou
            slice_id = "slc_<your_slice>"
            scenario_test = client.validate.create_scenario_test(
                "Example test",
                slice_id=slice_id,
                evaluation_criteria=[bbox_iou(confidence_threshold=0.8) > 0.5]
            )

        Args:
            enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False
            iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0
            confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0
        """
        # Delegate to the base config, forwarding any extra server-side options.
        return super().__call__(
            enforce_label_match=enforce_label_match,
            iou_threshold=iou_threshold,
            confidence_threshold=confidence_threshold,
            **kwargs,
        )

    @classmethod
    def expected_name(cls) -> str:
        # Server-side identifier this config attaches to.
        return "bbox_iou"
17
52
18
53
19
class PolygonMAPConfig(EvalFunctionConfig):
    """Config builder for the bbox/polygon mean-average-precision eval function."""

    def __call__(
        self,
        iou_threshold: float = 0.5,
        **kwargs,
    ):
        """Configures a call to :class:`PolygonMAP` object.
        ::

            import nucleus

            client = nucleus.NucleusClient(YOUR_SCALE_API_KEY)
            bbox_map: PolygonMAPConfig = client.validate.eval_functions.bbox_map
            slice_id = "slc_<your_slice>"
            scenario_test = client.validate.create_scenario_test(
                "Example test",
                slice_id=slice_id,
                evaluation_criteria=[bbox_map(iou_threshold=0.6) > 0.8]
            )

        Args:
            iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.5
        """
        # Delegate to the base config, forwarding any extra server-side options.
        return super().__call__(
            iou_threshold=iou_threshold,
            **kwargs,
        )

    @classmethod
    def expected_name(cls) -> str:
        # Server-side identifier this config attaches to.
        return "bbox_map"
23
85
24
86
25
class PolygonRecallConfig(EvalFunctionConfig):
    """Config builder for the bbox/polygon recall eval function."""

    def __call__(
        self,
        enforce_label_match: bool = False,
        iou_threshold: float = 0.5,
        confidence_threshold: float = 0.0,
        **kwargs,
    ):
        """Configures a call to :class:`PolygonRecall` object.
        ::

            import nucleus

            client = nucleus.NucleusClient(YOUR_SCALE_API_KEY)
            bbox_recall: PolygonRecallConfig = client.validate.eval_functions.bbox_recall
            slice_id = "slc_<your_slice>"
            scenario_test = client.validate.create_scenario_test(
                "Example test",
                slice_id=slice_id,
                evaluation_criteria=[bbox_recall(iou_threshold=0.6, confidence_threshold=0.4) > 0.9]
            )

        Args:
            enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False
            iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.5
            confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0
        """
        # Delegate to the base config, forwarding any extra server-side options.
        return super().__call__(
            enforce_label_match=enforce_label_match,
            iou_threshold=iou_threshold,
            confidence_threshold=confidence_threshold,
            **kwargs,
        )

    @classmethod
    def expected_name(cls) -> str:
        # Server-side identifier this config attaches to.
        return "bbox_recall"
29
124
30
125
31
class PolygonPrecisionConfig(EvalFunctionConfig):
    """Config builder for the bbox/polygon precision eval function."""

    def __call__(
        self,
        enforce_label_match: bool = False,
        iou_threshold: float = 0.5,
        confidence_threshold: float = 0.0,
        **kwargs,
    ):
        """Configures a call to :class:`PolygonPrecision` object.
        ::

            import nucleus

            client = nucleus.NucleusClient(YOUR_SCALE_API_KEY)
            bbox_precision: PolygonPrecisionConfig = client.validate.eval_functions.bbox_precision
            slice_id = "slc_<your_slice>"
            scenario_test = client.validate.create_scenario_test(
                "Example test",
                slice_id=slice_id,
                evaluation_criteria=[bbox_precision(iou_threshold=0.6, confidence_threshold=0.4) > 0.9]
            )

        Args:
            enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False
            iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.5
            confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0
        """
        # Delegate to the base config, forwarding any extra server-side options.
        return super().__call__(
            enforce_label_match=enforce_label_match,
            iou_threshold=iou_threshold,
            confidence_threshold=confidence_threshold,
            **kwargs,
        )

    @classmethod
    def expected_name(cls) -> str:
        # Server-side identifier this config attaches to.
        return "bbox_precision"
35
163
36
164
37
class CategorizationF1Config(EvalFunctionConfig):
    """Config builder for the categorization F1 eval function."""

    def __call__(
        self,
        confidence_threshold: Optional[float] = None,
        f1_method: Optional[str] = None,
        **kwargs,
    ):
        """Configure an evaluation of :class:`CategorizationF1`.
        ::

            import nucleus

            client = nucleus.NucleusClient(YOUR_SCALE_API_KEY)
            cat_f1: CategorizationF1Config = client.validate.eval_functions.cat_f1
            slice_id = "slc_<your_slice>"
            scenario_test = client.validate.create_scenario_test(
                "Example test",
                slice_id=slice_id,
                evaluation_criteria=[cat_f1(confidence_threshold=0.6, f1_method="weighted") > 0.7]
            )

        Args:
            confidence_threshold: minimum confidence threshold for predictions to be taken into account for evaluation.
                Must be in [0, 1]. If None, the server-side default is used.
            f1_method: {'micro', 'macro', 'samples','weighted', 'binary'}, \
                default='macro'
                This parameter is required for multiclass/multilabel targets.
                If ``None``, the scores for each class are returned. Otherwise, this
                determines the type of averaging performed on the data:

                ``'binary'``:
                    Only report results for the class specified by ``pos_label``.
                    This is applicable only if targets (``y_{true,pred}``) are binary.
                ``'micro'``:
                    Calculate metrics globally by counting the total true positives,
                    false negatives and false positives.
                ``'macro'``:
                    Calculate metrics for each label, and find their unweighted
                    mean. This does not take label imbalance into account.
                ``'weighted'``:
                    Calculate metrics for each label, and find their average weighted
                    by support (the number of true instances for each label). This
                    alters 'macro' to account for label imbalance; it can result in an
                    F-score that is not between precision and recall.
                ``'samples'``:
                    Calculate metrics for each instance, and find their average (only
                    meaningful for multilabel classification where this differs from
                    :func:`accuracy_score`).
        """
        # Fix: forward **kwargs to the base call — previously extra options
        # were accepted by the signature but silently discarded, unlike every
        # sibling config class.
        return super().__call__(
            confidence_threshold=confidence_threshold,
            f1_method=f1_method,
            **kwargs,
        )

    @classmethod
    def expected_name(cls) -> str:
        # Server-side identifier this config attaches to.
        return "cat_f1"
41
221
42
222
43
class CustomEvalFunction(EvalFunctionConfig):
    """Config for user-defined (non-public) eval functions."""

    @classmethod
    def expected_name(cls) -> str:
        # Fix: typo in user-facing message ("Custm" -> "Custom").
        raise NotImplementedError(
            "Custom evaluation functions are coming soon"
        )  # Placeholder: See super().eval_func_entry for actual name
49
229
50
230
51
- class StandardEvalFunction (BaseEvalFunction ):
231
+ class StandardEvalFunction (EvalFunctionConfig ):
52
232
"""Class for standard Model CI eval functions that have not been added as attributes on
53
233
AvailableEvalFunctions yet.
54
234
"""
@@ -65,7 +245,7 @@ def expected_name(cls) -> str:
65
245
return "public_function" # Placeholder: See super().eval_func_entry for actual name
66
246
67
247
68
- class EvalFunctionNotAvailable (BaseEvalFunction ):
248
+ class EvalFunctionNotAvailable (EvalFunctionConfig ):
69
249
def __init__ (
70
250
self , not_available_name : str
71
251
): # pylint: disable=super-init-not-called
@@ -89,13 +269,14 @@ def expected_name(cls) -> str:
89
269
90
270
91
271
# Union of every concrete eval-function config exposed by this module.
EvalFunction = Union[
    PolygonIOUConfig,
    PolygonMAPConfig,
    PolygonPrecisionConfig,
    PolygonRecallConfig,
    CategorizationF1Config,
    CustomEvalFunction,
    EvalFunctionNotAvailable,
    StandardEvalFunction,
]
100
281
101
282
@@ -124,24 +305,24 @@ def __init__(self, available_functions: List[EvalFunctionEntry]):
124
305
f .name : f for f in available_functions if f .is_public
125
306
}
126
307
# NOTE: Public are assigned
127
- self ._public_to_function : Dict [str , BaseEvalFunction ] = {}
308
+ self ._public_to_function : Dict [str , EvalFunctionConfig ] = {}
128
309
self ._custom_to_function : Dict [str , CustomEvalFunction ] = {
129
310
f .name : CustomEvalFunction (f )
130
311
for f in available_functions
131
312
if not f .is_public
132
313
}
133
- self .bbox_iou = self ._assign_eval_function_if_defined (BoundingBoxIOU ) # type: ignore
134
- self .bbox_precision = self ._assign_eval_function_if_defined (
135
- BoundingBoxPrecision # type: ignore
314
+ self .bbox_iou : PolygonIOUConfig = self ._assign_eval_function_if_defined (PolygonIOUConfig ) # type: ignore
315
+ self .bbox_precision : PolygonPrecisionConfig = self ._assign_eval_function_if_defined (
316
+ PolygonPrecisionConfig # type: ignore
136
317
)
137
- self .bbox_recall = self ._assign_eval_function_if_defined (
138
- BoundingBoxRecall # type: ignore
318
+ self .bbox_recall : PolygonRecallConfig = self ._assign_eval_function_if_defined (
319
+ PolygonRecallConfig # type: ignore
139
320
)
140
- self .bbox_map = self ._assign_eval_function_if_defined (
141
- BoundingBoxMeanAveragePrecision # type: ignore
321
+ self .bbox_map : PolygonMAPConfig = self ._assign_eval_function_if_defined (
322
+ PolygonMAPConfig # type: ignore
142
323
)
143
- self .cat_f1 = self ._assign_eval_function_if_defined (
144
- CategorizationF1 # type: ignore
324
+ self .cat_f1 : CategorizationF1Config = self ._assign_eval_function_if_defined (
325
+ CategorizationF1Config # type: ignore
145
326
)
146
327
147
328
# Add public entries that have not been implemented as an attribute on this class
@@ -163,7 +344,7 @@ def __repr__(self):
163
344
)
164
345
165
346
@property
166
- def public_functions (self ) -> Dict [str , BaseEvalFunction ]:
347
+ def public_functions (self ) -> Dict [str , EvalFunctionConfig ]:
167
348
"""Standard functions provided by Model CI.
168
349
169
350
Notes:
0 commit comments