
Commit c9f1f59

Validate feature: setting baseline models (#266)
add new set-model-as-baseline functions to the client, remove add_criterion in favor of add_eval_function, bump version number and changelog
1 parent f9d1173 commit c9f1f59

18 files changed, +160 -89 lines
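
In 0.8.3 terms, the headline change is that scenario tests no longer require threshold comparisons up front. A minimal sketch of the new flow, assembled from the docstrings in this diff (the API key, slice ID, and model ID are placeholders):

```python
import nucleus

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
e = client.validate.eval_functions

# Tests are now created from plain eval functions; the old
# `evaluation_criteria=[e.bbox_iou() > 0.5]` comparisons are gone.
scenario_test = client.validate.create_scenario_test(
    name="sample_scenario_test",
    slice_id="YOUR_SLICE_ID",
    evaluation_functions=[e.bbox_iou()],
)

# Later, once some model has been evaluated on the test, its scores can
# be promoted to serve as the thresholds for every metric.
scenario_test.set_baseline_model("my_baseline_model_id")
```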

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,10 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/nucleus-python-client)
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.8.3](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.8.3) - 2022-03-29
+
+### Added
+- new Validate functionality to initialize scenario tests without a threshold, and to set test thresholds based on a baseline model.
 ## [0.8.2](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.8.2) - 2022-03-18
 
 ### Added

cli/tests.py

Lines changed: 1 addition & 1 deletion
@@ -109,7 +109,7 @@ def build_scenario_test_info_tree(client, scenario_test, tree):
     slice_branch.add(f"name: '{slice_info['name']}'")
     slice_branch.add(f"len: {len(slc.items)}")
     slice_branch.add(f"url: {slice_url}")
-    criteria = scenario_test.get_criteria()
+    criteria = scenario_test.get_eval_functions()
     criteria_branch = tree.add(":crossed_flags: Criteria")
     for criterion in criteria:
         pretty_criterion = format_criterion(

nucleus/validate/client.py

Lines changed: 15 additions & 13 deletions
@@ -4,13 +4,13 @@
 from nucleus.job import AsyncJob
 
 from .constants import SCENARIO_TEST_ID_KEY
-from .data_transfer_objects.eval_function import (
-    EvaluationCriterion,
-    GetEvalFunctions,
-)
+from .data_transfer_objects.eval_function import GetEvalFunctions
 from .data_transfer_objects.scenario_test import CreateScenarioTestRequest
 from .errors import CreateScenarioTestError
-from .eval_functions.available_eval_functions import AvailableEvalFunctions
+from .eval_functions.available_eval_functions import (
+    AvailableEvalFunctions,
+    EvalFunction,
+)
 from .scenario_test import ScenarioTest
 
 SUCCESS_KEY = "success"
@@ -51,7 +51,7 @@ def create_scenario_test(
         self,
         name: str,
         slice_id: str,
-        evaluation_criteria: List[EvaluationCriterion],
+        evaluation_functions: List[EvalFunction],
     ) -> ScenarioTest:
         """Creates a new Scenario Test from an existing Nucleus :class:`Slice`:. ::
 
@@ -61,28 +61,30 @@ def create_scenario_test(
             scenario_test = client.validate.create_scenario_test(
                 name="sample_scenario_test",
                 slice_id="YOUR_SLICE_ID",
-                evaluation_criteria=[client.validate.eval_functions.bbox_iou() > 0.5]
+                evaluation_functions=[client.validate.eval_functions.bbox_iou()]
             )
 
         Args:
             name: unique name of test
             slice_id: id of (pre-defined) slice of items to evaluate test on.
-            evaluation_criteria: :class:`EvaluationCriterion` defines a pass/fail criteria for the test. Created with a
-                comparison with an eval functions. See :class:`eval_functions`.
+            evaluation_functions: :class:`EvalFunctionEntry` defines an evaluation metric for the test.
+                Created with an element from the list of available eval functions. See :class:`eval_functions`.
 
         Returns:
             Created ScenarioTest object.
         """
-        if not evaluation_criteria:
+        if not evaluation_functions:
             raise CreateScenarioTestError(
-                "Must pass an evaluation_criteria to the scenario test! I.e. "
-                "evaluation_criteria = [client.validate.eval_functions.bbox_iou() > 0.5]"
+                "Must pass an evaluation_function to the scenario test! I.e. "
+                "evaluation_functions=[client.validate.eval_functions.bbox_iou()]"
             )
         response = self.connection.post(
            CreateScenarioTestRequest(
                name=name,
                slice_id=slice_id,
-                evaluation_criteria=evaluation_criteria,
+                evaluation_functions=[
+                    ef.to_entry() for ef in evaluation_functions  # type:ignore
+                ],
            ).dict(),
            "validate/scenario_test",
        )
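
For callers upgrading from 0.8.2, the migration is mechanical: drop the threshold comparison and rename the keyword argument. A sketch, reusing the placeholders from the docstring above:

```python
import nucleus

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")

# 0.8.2 passed comparisons:
#   evaluation_criteria=[client.validate.eval_functions.bbox_iou() > 0.5]
# 0.8.3 passes the eval functions themselves; thresholds are applied later,
# e.g. via ScenarioTest.set_baseline_model or ScenarioTestMetric.set_threshold.
scenario_test = client.validate.create_scenario_test(
    name="sample_scenario_test",
    slice_id="YOUR_SLICE_ID",
    evaluation_functions=[client.validate.eval_functions.bbox_iou()],
)
```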

nucleus/validate/constants.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 THRESHOLD_KEY = "threshold"
 SCENARIO_TEST_ID_KEY = "scenario_test_id"
 SCENARIO_TEST_NAME_KEY = "scenario_test_name"
+SCENARIO_TEST_METRICS_KEY = "scenario_test_metrics"
 
 
 class ThresholdComparison(str, Enum):

nucleus/validate/data_transfer_objects/scenario_test.py

Lines changed: 2 additions & 2 deletions
@@ -4,13 +4,13 @@
 
 from nucleus.pydantic_base import ImmutableModel
 
-from .eval_function import EvaluationCriterion
+from .eval_function import EvalFunctionEntry
 
 
 class CreateScenarioTestRequest(ImmutableModel):
     name: str
     slice_id: str
-    evaluation_criteria: List[EvaluationCriterion]
+    evaluation_functions: List[EvalFunctionEntry]
 
     @validator("slice_id")
     def startswith_slice_indicator(cls, v):  # pylint: disable=no-self-argument

nucleus/validate/data_transfer_objects/scenario_test_metric.py

Lines changed: 1 addition & 5 deletions
@@ -1,12 +1,8 @@
 from nucleus.pydantic_base import ImmutableModel
 
-from ..constants import ThresholdComparison
 
-
-class AddScenarioTestMetric(ImmutableModel):
+class AddScenarioTestFunction(ImmutableModel):
     """Data transfer object to add a scenario test."""
 
     scenario_test_name: str
     eval_function_id: str
-    threshold: float
-    threshold_comparison: ThresholdComparison

nucleus/validate/eval_functions/base_eval_function.py

Lines changed: 3 additions & 0 deletions
@@ -58,3 +58,6 @@ def _op_to_test_metric(self, comparison: ThresholdComparison, value):
             threshold_comparison=comparison,
             threshold=value,
         )
+
+    def to_entry(self):
+        return self.eval_func_entry
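
`to_entry` simply hands back the wire-format `EvalFunctionEntry` kept inside the client-side wrapper, which is what `create_scenario_test` now serializes. An illustrative toy of the delegation pattern, with hypothetical stand-in classes (not the library's own):

```python
from dataclasses import dataclass


@dataclass
class FakeEntry:  # stand-in for EvalFunctionEntry
    id: str


@dataclass
class FakeEvalFunction:  # stand-in for the EvalFunction wrapper
    eval_func_entry: FakeEntry

    def to_entry(self) -> FakeEntry:
        # Return the serializable entry, exactly as base_eval_function does.
        return self.eval_func_entry


# Mirrors the list comprehension in client.py's create_scenario_test:
entries = [ef.to_entry() for ef in [FakeEvalFunction(FakeEntry("ef_123"))]]
```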

nucleus/validate/scenario_test.py

Lines changed: 61 additions & 28 deletions
@@ -5,14 +5,22 @@
 and have confidence that they’re always shipping the best model.
 """
 from dataclasses import dataclass, field
-from typing import List
+from typing import List, Optional
 
 from ..connection import Connection
 from ..constants import NAME_KEY, SLICE_ID_KEY
 from ..dataset_item import DatasetItem
-from .data_transfer_objects.eval_function import EvaluationCriterion
+from .constants import (
+    EVAL_FUNCTION_ID_KEY,
+    SCENARIO_TEST_ID_KEY,
+    SCENARIO_TEST_METRICS_KEY,
+    THRESHOLD_COMPARISON_KEY,
+    THRESHOLD_KEY,
+    ThresholdComparison,
+)
 from .data_transfer_objects.scenario_test_evaluations import GetEvalHistory
-from .data_transfer_objects.scenario_test_metric import AddScenarioTestMetric
+from .data_transfer_objects.scenario_test_metric import AddScenarioTestFunction
+from .eval_functions.available_eval_functions import EvalFunction
 from .scenario_test_evaluation import ScenarioTestEvaluation
 from .scenario_test_metric import ScenarioTestMetric
 
@@ -36,6 +44,7 @@ class ScenarioTest:
     connection: Connection = field(repr=False)
     name: str = field(init=False)
     slice_id: str = field(init=False)
+    baseline_model_id: Optional[str] = None
 
     def __post_init__(self):
         # TODO(gunnar): Remove this pattern. It's too slow. We should get all the info required in one call
@@ -45,10 +54,10 @@ def __post_init__(self):
         self.name = response[NAME_KEY]
         self.slice_id = response[SLICE_ID_KEY]
 
-    def add_criterion(
-        self, evaluation_criterion: EvaluationCriterion
+    def add_eval_function(
+        self, eval_function: EvalFunction
     ) -> ScenarioTestMetric:
-        """Creates and adds a new criteria to the :class:`ScenarioTest`. ::
+        """Creates and adds a new evaluation metric to the :class:`ScenarioTest`. ::
 
             import nucleus
             client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
@@ -58,49 +67,52 @@ def add_criterion(
 
             e = client.validate.eval_functions
             # Assuming a user would like to add all available public evaluation functions as criteria
-            scenario_test.add_criterion(
-                e.bbox_iou() > 0.5
+            scenario_test.add_eval_function(
+                e.bbox_iou
             )
-            scenario_test.add_criterion(
-                e.bbox_map() > 0.85
+            scenario_test.add_eval_function(
+                e.bbox_map
             )
-            scenario_test.add_criterion(
-                e.bbox_precision() > 0.7
+            scenario_test.add_eval_function(
+                e.bbox_precision
            )
-            scenario_test.add_criterion(
-                e.bbox_recall() > 0.6
+            scenario_test.add_eval_function(
+                e.bbox_recall
            )
 
        Args:
-            evaluation_criterion: :class:`EvaluationCriterion` created by comparison with an :class:`EvalFunction`
+            eval_function: :class:`EvalFunction`
 
        Returns:
            The created ScenarioTestMetric object.
        """
        response = self.connection.post(
-            AddScenarioTestMetric(
+            AddScenarioTestFunction(
                scenario_test_name=self.name,
-                eval_function_id=evaluation_criterion.eval_function_id,
-                threshold=evaluation_criterion.threshold,
-                threshold_comparison=evaluation_criterion.threshold_comparison,
+                eval_function_id=eval_function.id,
            ).dict(),
-            "validate/scenario_test_metric",
+            "validate/scenario_test_eval_function",
        )
+        print(response)
        return ScenarioTestMetric(
-            scenario_test_id=response["scenario_test_id"],
-            eval_function_id=response["eval_function_id"],
-            threshold=evaluation_criterion.threshold,
-            threshold_comparison=evaluation_criterion.threshold_comparison,
+            scenario_test_id=response[SCENARIO_TEST_ID_KEY],
+            eval_function_id=response[EVAL_FUNCTION_ID_KEY],
+            threshold=response.get(THRESHOLD_KEY, None),
+            threshold_comparison=response.get(
+                THRESHOLD_COMPARISON_KEY,
+                ThresholdComparison.GREATER_THAN_EQUAL_TO,
+            ),
+            connection=self.connection,
        )
 
-    def get_criteria(self) -> List[ScenarioTestMetric]:
+    def get_eval_functions(self) -> List[ScenarioTestMetric]:
        """Retrieves all criteria of the :class:`ScenarioTest`. ::
 
            import nucleus
            client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
            scenario_test = client.validate.scenario_tests[0]
 
-            scenario_test.get_criteria()
+            scenario_test.get_eval_functions()
 
        Returns:
            A list of ScenarioTestMetric objects.
@@ -109,8 +121,8 @@ def get_criteria(self) -> List[ScenarioTestMetric]:
            f"validate/scenario_test/{self.id}/metrics",
        )
        return [
-            ScenarioTestMetric(**metric)
-            for metric in response["scenario_test_metrics"]
+            ScenarioTestMetric(**metric, connection=self.connection)
+            for metric in response[SCENARIO_TEST_METRICS_KEY]
        ]
 
    def get_eval_history(self) -> List[ScenarioTestEvaluation]:
@@ -141,3 +153,24 @@ def get_items(self) -> List[DatasetItem]:
        return [
            DatasetItem.from_json(item) for item in response[DATASET_ITEMS_KEY]
        ]
+
+    def set_baseline_model(self, model_id: str):
+        """Sets a new baseline model for the ScenarioTest. To be eligible as a baseline,
+        this scenario test must have been evaluated using that model. The baseline model's performance
+        is used as the threshold for all metrics against which other models are compared. ::
+
+            import nucleus
+            client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
+            scenario_test = client.validate.scenario_tests[0]
+
+            scenario_test.set_baseline_model('my_baseline_model_id')
+
+        Returns:
+            The new baseline model ID.
+        """
+        response = self.connection.post(
+            {},
+            f"validate/scenario_test/{self.id}/set_baseline_model/{model_id}",
+        )
+        self.baseline_model_id = response.get("baseline_model_id")
+        return self.baseline_model_id
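
Taken together, `add_eval_function` and `set_baseline_model` support a threshold-free workflow. A sketch following the docstrings above (the model ID is a placeholder and must belong to a model already evaluated on this test):

```python
import nucleus

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
scenario_test = client.validate.scenario_tests[0]
e = client.validate.eval_functions

# Attach an eval function; the server response supplies the threshold and
# comparison (defaulting to >=), and the returned ScenarioTestMetric keeps
# a connection so it can be updated later.
metric = scenario_test.add_eval_function(e.bbox_iou)

# Promote a previously evaluated model to baseline; its performance becomes
# the threshold that other models are compared against.
baseline_id = scenario_test.set_baseline_model("my_baseline_model_id")
```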

nucleus/validate/scenario_test_metric.py

Lines changed: 23 additions & 4 deletions
@@ -1,14 +1,33 @@
-from nucleus.pydantic_base import ImmutableModel
+from dataclasses import dataclass, field
+from typing import Dict, Optional
 
+from ..connection import Connection
 from .constants import ThresholdComparison
 
 
-class ScenarioTestMetric(ImmutableModel):
+@dataclass
+class ScenarioTestMetric:
     """A Scenario Test Metric is an evaluation function combined with a comparator and associated with a Scenario Test.
     Scenario Test Metrics serve as the basis when evaluating a Model on a Scenario Test.
     """
 
     scenario_test_id: str
     eval_function_id: str
-    threshold: float
-    threshold_comparison: ThresholdComparison
+    threshold: Optional[float]
+    connection: Connection
+    eval_func_arguments: Optional[Dict] = field(default_factory=dict)
+    threshold_comparison: ThresholdComparison = (
+        ThresholdComparison.GREATER_THAN_EQUAL_TO
+    )
+
+    def set_threshold(self, threshold: Optional[float] = None) -> None:
+        """Sets the threshold of the metric to the new value passed in as a parameter.
+        Args:
+            threshold (float): The new threshold value.
+        """
+        payload = {"threshold": threshold}
+        response = self.connection.post(
+            payload,
+            f"validate/metric/set_threshold/{self.scenario_test_id}/{self.eval_function_id}",
+        )
+        self.threshold = response.get("threshold")
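
Since `ScenarioTestMetric` now carries a connection, thresholds can also be adjusted per metric after the fact. A short sketch continuing the example above:

```python
# Fetch the metrics attached to the test and pin an explicit threshold on one.
# set_threshold posts to validate/metric/set_threshold/<test_id>/<function_id>
# and stores whatever threshold the server echoes back.
metrics = scenario_test.get_eval_functions()
metrics[0].set_threshold(0.5)
```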

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ exclude = '''
 
 [tool.poetry]
 name = "scale-nucleus"
-version = "0.8.2"
+version = "0.8.3"
 description = "The official Python client library for Nucleus, the Data Platform for AI"
 license = "MIT"
 authors = ["Scale AI Nucleus Team <[email protected]>"]

tests/cli/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ def scenario_test(CLIENT, test_slice, annotations, predictions):
     scenario_test = CLIENT.validate.create_scenario_test(
         name=test_name,
         slice_id=test_slice.id,
-        evaluation_criteria=[CLIENT.validate.eval_functions.bbox_recall > 0.5],
+        evaluation_functions=[CLIENT.validate.eval_functions.bbox_recall],
     )
     yield scenario_test
 
tests/helpers.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 NUCLEUS_PYTEST_USER_ID = "60ad648c85db770026e9bf77"
 
 EVAL_FUNCTION_THRESHOLD = 0.5
-EVAL_FUNCTION_COMPARISON = ThresholdComparison.GREATER_THAN
+EVAL_FUNCTION_COMPARISON = ThresholdComparison.GREATER_THAN_EQUAL_TO
 
 
 TEST_IMG_URLS = [

tests/test_annotation.py

Lines changed: 1 addition & 0 deletions
@@ -724,6 +724,7 @@ def test_default_category_gt_upload_async(dataset):
     assert_partial_equality(expected, result)
 
 
+@pytest.mark.skip("Need to adjust error message on taxonomy failure")
 @pytest.mark.integration
 def test_non_existent_taxonomy_category_gt_upload_async(dataset):
     annotation = CategoryAnnotation.from_json(
