Skip to content

Commit 1a4a37f

Browse files
authored
Caching and Parallelism Improvements to Data Streaming (#7)
* feat: add caching and parallel fetching * cleanup * fix: support multiprocessing context * cleanup Co-authored-by: Pablo <[email protected]>
1 parent f812a39 commit 1a4a37f

File tree

6 files changed

+132
-66
lines changed

6 files changed

+132
-66
lines changed

sdk/diffgram/core/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ def __init__(
5353
project_string_id = project_string_id,
5454
client_id = client_id,
5555
client_secret = client_secret)
56+
self.client_id = client_id
57+
self.client_secret = client_secret
5658
self.file = FileConstructor(self)
5759
self.train = Train(self)
5860
self.job = Job(self)

sdk/diffgram/core/diffgram_dataset_iterator.py

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,29 @@
11
from PIL import Image, ImageDraw
22
from imageio import imread
33
import numpy as np
4+
import traceback
5+
import sys
6+
from threading import Thread
7+
from concurrent.futures import ThreadPoolExecutor
8+
49

510
class DiffgramDatasetIterator:
611

7-
def __init__(self, project, diffgram_file_id_list, validate_ids = True):
12+
def __init__(self, project,
13+
diffgram_file_id_list,
14+
validate_ids = True,
15+
max_size_cache = 1073741824,
16+
max_num_concurrent_fetches = 25):
817
"""
918
1019
:param project (sdk.core.core.Project): A Project object from the Diffgram SDK
1120
:param diffgram_file_list (list): An arbitrary number of file ID's from Diffgram.
1221
"""
1322
self.diffgram_file_id_list = diffgram_file_id_list
14-
23+
self.max_size_cache = max_size_cache
24+
self.pool = ThreadPoolExecutor(max_num_concurrent_fetches)
1525
self.project = project
26+
self.file_cache = {}
1627
self._internal_file_list = []
1728
if validate_ids:
1829
self.__validate_file_ids()
@@ -25,22 +36,58 @@ def __iter__(self):
2536
def __len__(self):
2637
return len(self.diffgram_file_id_list)
2738

28-
def __getitem__(self, idx):
29-
diffgram_file = self.project.file.get_by_id(self.diffgram_file_id_list[idx], with_instances = True)
39+
def save_file_in_cache(self, idx, instance_data):
40+
# If size of cache greater than 1GB (Default)
41+
if sys.getsizeof(self.file_cache) > self.max_size_cache:
42+
keys = list(self.file_cache.keys())
43+
latest_keys = keys[:-10] # Get oldest 10 elements
44+
for k in latest_keys:
45+
self.file_cache.pop(k)
46+
47+
self.file_cache[idx] = instance_data
48+
49+
def get_next_n_items(self, idx, num_items = 25):
50+
"""
51+
Get next N items and save them to cache proactively.
52+
:param idx:
53+
:param n:
54+
:return:
55+
"""
56+
latest_index = idx + num_items
57+
if latest_index >= len(self.diffgram_file_id_list):
58+
latest_index = len(self.diffgram_file_id_list)
59+
60+
for i in range(idx + 1, latest_index):
61+
self.pool.submit(self.__get_file_data_for_index, (i,))
62+
return True
63+
64+
def __get_file_data_for_index(self, idx):
65+
diffgram_file = self.project.file.get_by_id(self.diffgram_file_id_list[idx], with_instances = True, use_session = False)
3066
instance_data = self.get_file_instances(diffgram_file)
67+
self.save_file_in_cache(idx, instance_data)
3168
return instance_data
3269

70+
def __getitem__(self, idx):
71+
if self.file_cache.get(idx):
72+
return self.file_cache.get(idx)
73+
74+
result = self.__get_file_data_for_index(idx)
75+
76+
self.get_next_n_items(idx, num_items = 25)
77+
78+
return result
79+
3380
def __next__(self):
34-
file_id = self.diffgram_file_id_list[self.current_file_index]
35-
diffgram_file = self.project.file.get_by_id(file_id, with_instances = True)
36-
instance_data = self.get_file_instances(diffgram_file)
81+
if self.file_cache.get(self.current_file_index):
82+
return self.file_cache.get(self.current_file_index)
83+
instance_data = self.__get_file_data_for_index(self.current_file_index)
3784
self.current_file_index += 1
3885
return instance_data
3986

4087
def __validate_file_ids(self):
4188
if not self.diffgram_file_id_list:
4289
return
43-
result = self.project.file.file_list_exists(self.diffgram_file_id_list)
90+
result = self.project.file.file_list_exists(self.diffgram_file_id_list, use_session = False)
4491
if not result:
4592
raise Exception(
4693
'Some file IDs do not belong to the project. Please provide only files from the same project.')
@@ -56,7 +103,9 @@ def get_image_data(self, diffgram_file):
56103
if i < MAX_RETRIES - 1:
57104
continue
58105
else:
59-
raise e
106+
print('Fetch Image Failed: Diffgram File ID: {}'.format(diffgram_file.id))
107+
print(traceback.format_exc())
108+
return None
60109
return image
61110
else:
62111
raise Exception('Pytorch datasets only support images. Please provide only file_ids from images')

sdk/diffgram/file/file_constructor.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from diffgram.job.job import Job
55
import json
66
import os
7-
7+
import requests
8+
from requests.auth import HTTPDigestAuth
89

910
class FileConstructor():
1011
"""
@@ -401,7 +402,7 @@ def get_file_list(self, id_list: list, with_instances: bool = False):
401402

402403
raise NotImplementedError
403404

404-
def file_list_exists(self, id_list):
405+
def file_list_exists(self, id_list, use_session = True):
405406
"""
406407
Verifies that the given ID list exists inside the project.
407408
:param id_list:
@@ -413,10 +414,16 @@ def file_list_exists(self, id_list):
413414
spec_dict = {
414415
'file_id_list': id_list
415416
}
416-
response = self.client.session.post(
417-
self.client.host + url,
418-
json = spec_dict)
419-
417+
if use_session:
418+
response = self.client.session.post(
419+
self.client.host + url,
420+
json = spec_dict)
421+
else:
422+
response = requests.post(
423+
url = self.client.host + url,
424+
json = spec_dict,
425+
auth = HTTPDigestAuth(self.client.client_id, self.client.client_secret)
426+
)
420427
self.client.handle_errors(response)
421428

422429
response_json = response.json()
@@ -428,7 +435,8 @@ def file_list_exists(self, id_list):
428435

429436
def get_by_id(self,
430437
id: int,
431-
with_instances: bool = False):
438+
with_instances: bool = False,
439+
use_session = True):
432440
"""
433441
returns Diffgram File object
434442
"""
@@ -450,9 +458,15 @@ def get_by_id(self,
450458
}
451459
file_response_key = 'file_serialized'
452460

453-
response = self.client.session.post(
454-
self.client.host + endpoint,
455-
json = spec_dict)
461+
if use_session:
462+
response = self.client.session.post(
463+
self.client.host + endpoint,
464+
json = spec_dict)
465+
else:
466+
# Add Auth
467+
response = requests.post(self.client.host + endpoint,
468+
json = spec_dict,
469+
auth = HTTPDigestAuth(self.client.client_id, self.client.client_secret))
456470

457471
self.client.handle_errors(response)
458472

sdk/diffgram/file/view.py

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,60 @@
1+
import requests
2+
from requests.auth import HTTPDigestAuth
13

24

35
def get_file_id():
4-
"""
5-
Get Project file id
6+
"""
7+
Get Project file id
68
7-
Arguments
8-
project string id
9-
working directory?
10-
filename??
9+
Arguments
10+
project string id
11+
working directory?
12+
filename??
1113
12-
Future
13-
How are we handling video with this?
14-
API method for this?
14+
Future
15+
How are we handling video with this?
16+
API method for this?
1517
16-
"""
17-
pass
18+
"""
19+
pass
1820

1921

22+
def get_label_file_dict(self, use_session = True):
23+
"""
24+
Get Project label file id dict for project
2025
21-
def get_label_file_dict(self):
22-
"""
23-
Get Project label file id dict for project
26+
Arguments
27+
self
2428
25-
Arguments
26-
self
27-
28-
Expects
29-
self.project_string_id
30-
self.directory_id
29+
Expects
30+
self.project_string_id
31+
self.directory_id
3132
32-
Returns
33-
sets self.name_to_file_id to the dict returned
33+
Returns
34+
sets self.name_to_file_id to the dict returned
3435
35-
"""
36-
if self.project_string_id is None:
37-
raise Exception("No project string." + \
38-
"Set a project string using .auth()")
36+
"""
37+
if self.project_string_id is None:
38+
raise Exception("No project string." + \
39+
"Set a project string using .auth()")
3940

40-
if type(self.project_string_id) != str:
41-
raise Exception("project_string_id must be of type String")
41+
if type(self.project_string_id) != str:
42+
raise Exception("project_string_id must be of type String")
4243

43-
endpoint = "/api/v1/project/" + self.project_string_id + \
44-
"/labels/view/name_to_file_id"
44+
endpoint = "/api/v1/project/" + self.project_string_id + \
45+
"/labels/view/name_to_file_id"
46+
if use_session:
47+
response = self.session.get(self.host + endpoint)
48+
else:
49+
# Add Auth
50+
response = requests.get(self.host + endpoint,
51+
headers = {'directory_id': str(self.directory_id)},
52+
auth = HTTPDigestAuth(self.client_id, self.client_secret))
4553

46-
response = self.session.get(self.host + endpoint)
47-
48-
self.handle_errors(response)
49-
50-
data = response.json()
51-
52-
if data["log"]["success"] == True:
53-
self.name_to_file_id = data["name_to_file_id"]
54-
else:
55-
raise Exception(data["log"]["errors"])
54+
self.handle_errors(response)
5655

56+
data = response.json()
57+
if data["log"]["success"] == True:
58+
self.name_to_file_id = data["name_to_file_id"]
59+
else:
60+
raise Exception(data["log"]["errors"])

sdk/diffgram/pytorch_diffgram/diffgram_pytorch_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ def __get_next_page_of_data(self):
2828
def __getitem__(self, idx):
2929
if torch.is_tensor(idx):
3030
idx = idx.tolist()
31-
diffgram_file = self.project.file.get_by_id(self.diffgram_file_id_list[idx], with_instances = True)
3231

33-
sample = self.get_file_instances(diffgram_file)
32+
sample = super().__getitem__(idx)
33+
3434
if 'x_min_list' in sample:
3535
sample['x_min_list'] = torch.Tensor(sample['x_min_list'])
3636
if 'x_max_list' in sample:

sdk/diffgram/tensorflow_diffgram/diffgram_tensorflow_dataset.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,15 @@ def __getitem__(self, idx):
5454
return tf_example
5555

5656
def get_tf_train_example(self, idx):
57-
file_id = self.diffgram_file_id_list[idx]
58-
diffgram_file = self.project.file.get_by_id(file_id, with_instances = True)
59-
image = self.get_image_data(diffgram_file)
60-
instance_data = self.get_file_instances(diffgram_file)
57+
instance_data = super().__getitem__(idx)
6158
filename, file_extension = os.path.splitext(instance_data['diffgram_file'].image['original_filename'])
6259
label_names_bytes = [x.encode() for x in instance_data['label_name_list']]
6360
tf_example_dict = {
6461
'image/height': self.int64_feature(instance_data['diffgram_file'].image['height']),
6562
'image/width': self.int64_feature(instance_data['diffgram_file'].image['width']),
6663
'image/filename': self.bytes_feature(filename.encode()),
6764
'image/source_id': self.bytes_feature(filename.encode()),
68-
'image/encoded': self.bytes_feature(image.tobytes()),
65+
'image/encoded': self.bytes_feature(instance_data['image'].tobytes()),
6966
'image/format': self.bytes_feature(file_extension.encode()),
7067
'image/object/bbox/xmin': self.float_list_feature(instance_data['x_min_list']),
7168
'image/object/bbox/xmax': self.float_list_feature(instance_data['x_max_list']),

0 commit comments

Comments
 (0)