Skip to content

Commit 2d52a4f

Browse files
committed
make annotation requests in batches
1 parent 7363a0f commit 2d52a4f

File tree

4 files changed

+39
-8
lines changed

4 files changed

+39
-8
lines changed

nucleus/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
REFERENCE_ID_KEY = "reference_id"
103103
BACKEND_REFERENCE_ID_KEY = "ref_id" # TODO(355762): Our backend returns this instead of the "proper" key sometimes.
104104
REQUEST_ID_KEY = "requestId"
105+
REQUEST_IDS_KEY = "requestIds"
105106
SCENES_KEY = "scenes"
106107
SERIALIZED_REQUEST_KEY = "serialized_request"
107108
SEGMENTATIONS_KEY = "segmentations"

nucleus/dataset.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
format_prediction_response,
2121
paginate_generator,
2222
serialize_and_write_to_presigned_url,
23+
serialize_and_write_to_presigned_urls_in_batches,
2324
)
2425

2526
from .annotation import Annotation, check_all_mask_paths_remote
@@ -39,6 +40,7 @@
3940
NAME_KEY,
4041
REFERENCE_IDS_KEY,
4142
REQUEST_ID_KEY,
43+
REQUEST_IDS_KEY,
4244
SLICE_ID_KEY,
4345
UPDATE_KEY,
4446
VIDEO_UPLOAD_TYPE_KEY,
@@ -390,11 +392,11 @@ def annotate(
390392
"""
391393
if asynchronous:
392394
check_all_mask_paths_remote(annotations)
393-
request_id = serialize_and_write_to_presigned_url(
395+
request_ids = serialize_and_write_to_presigned_urls_in_batches(
394396
annotations, self.id, self._client
395397
)
396398
response = self._client.make_request(
397-
payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
399+
payload={REQUEST_IDS_KEY: request_ids, UPDATE_KEY: update},
398400
route=f"dataset/{self.id}/annotate?async=1",
399401
)
400402
return AsyncJob.from_json(response, self._client)
@@ -1384,11 +1386,11 @@ def upload_predictions(
13841386
if asynchronous:
13851387
check_all_mask_paths_remote(predictions)
13861388

1387-
request_id = serialize_and_write_to_presigned_url(
1389+
request_ids = serialize_and_write_to_presigned_urls_in_batches(
13881390
predictions, self.id, self._client
13891391
)
13901392
response = self._client.make_request(
1391-
payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
1393+
payload={REQUEST_IDS_KEY: request_ids, UPDATE_KEY: update},
13921394
route=f"dataset/{self.id}/model/{model.id}/uploadPredictions?async=1",
13931395
)
13941396
return AsyncJob.from_json(response, self._client)

nucleus/model_run.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222
from nucleus.job import AsyncJob
2323
from nucleus.utils import (
2424
format_prediction_response,
25-
serialize_and_write_to_presigned_url,
25+
serialize_and_write_to_presigned_urls_in_batches,
2626
)
2727

2828
from .constants import (
2929
ANNOTATIONS_KEY,
3030
DEFAULT_ANNOTATION_UPDATE_MODE,
31-
REQUEST_ID_KEY,
31+
REQUEST_IDS_KEY,
3232
UPDATE_KEY,
3333
)
3434
from .prediction import (
@@ -157,11 +157,11 @@ def predict(
157157
if asynchronous:
158158
check_all_mask_paths_remote(annotations)
159159

160-
request_id = serialize_and_write_to_presigned_url(
160+
request_ids = serialize_and_write_to_presigned_urls_in_batches(
161161
annotations, self.dataset_id, self._client
162162
)
163163
response = self._client.make_request(
164-
payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
164+
payload={REQUEST_IDS_KEY: request_ids, UPDATE_KEY: update},
165165
route=f"modelRun/{self.model_run_id}/predict?async=1",
166166
)
167167
return AsyncJob.from_json(response, self._client)

nucleus/utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,34 @@ def upload_to_presigned_url(presigned_url: str, file_pointer: IO):
273273
)
274274

275275

276+
def serialize_and_write_to_presigned_urls_in_batches(
277+
upload_units: Sequence[
278+
Union[DatasetItem, Annotation, LidarScene, VideoScene]
279+
],
280+
dataset_id: str,
281+
client,
282+
batch_size: int = 10000,
283+
):
284+
"""This helper function can be used to serialize a list of API objects to batches of NDJSON files."""
285+
request_ids = []
286+
for i in range(0, len(upload_units), batch_size):
287+
upload_units_chunk = upload_units[i : i + batch_size]
288+
request_id = uuid.uuid4().hex
289+
response = client.make_request(
290+
payload={},
291+
route=f"dataset/{dataset_id}/signedUrl/{request_id}",
292+
requests_command=requests.get,
293+
)
294+
295+
strio = io.StringIO()
296+
serialize_and_write(upload_units_chunk, strio)
297+
strio.seek(0)
298+
upload_to_presigned_url(response["signed_url"], strio)
299+
300+
request_ids.append(request_id)
301+
return request_ids
302+
303+
276304
def serialize_and_write_to_presigned_url(
277305
upload_units: Sequence[
278306
Union[DatasetItem, Annotation, LidarScene, VideoScene]

0 commit comments

Comments
 (0)