
Commit d3bdf34

VitalyFedyunin authored and facebook-github-bot committed
Introducing DataChunk for DataPipes batching (pytorch#62768)
Summary: Pull Request resolved: pytorch#62768

This is part of the TorchArrow DF support preparation; it is split into multiple PRs to simplify the review process.

Test Plan: Imported from OSS
Reviewed By: ejguan
Differential Revision: D30149090
Pulled By: VitalyFedyunin
fbshipit-source-id: a36b5ff56e2ac6b06060014d4cd41b487754acb8
1 parent 5e5de75 commit d3bdf34

File tree: 8 files changed, +97 -39 lines


test/delete.py

Whitespace-only changes.

test/test_datapipe.py (+18 -9)

@@ -42,6 +42,7 @@
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch.utils.data import (
     DataLoader,
+    DataChunk,
     IterDataPipe,
     MapDataPipe,
     RandomSampler,

@@ -108,6 +109,14 @@ def create_temp_dir_and_files():
             (temp_sub_dir, temp_sub_file1_name, temp_sub_file2_name)]


+class TestDataChunk(TestCase):
+
+    def test_as_string(self):
+        elements = list(range(10))
+        chunk : DataChunk[int] = DataChunk(elements)
+        self.assertEquals(str(chunk), str(elements))
+
+
 class TestIterableDataPipeBasic(TestCase):

     def setUp(self):

@@ -141,7 +150,6 @@ def test_listdirfiles_iterable_datapipe(self):
             self.assertTrue((pathname in self.temp_files) or (pathname in self.temp_sub_files))
         self.assertEqual(count, len(self.temp_files) + len(self.temp_sub_files))

-
     def test_loadfilesfromdisk_iterable_datapipe(self):
         # test import datapipe class directly
         from torch.utils.data.datapipes.iter import (

@@ -216,7 +224,6 @@ def test_readfilesfromzip_iterable_datapipe(self):
            self.assertEqual(data_ref[1].read(), f.read())
            data_ref[1].close()

-
     def test_routeddecoder_iterable_datapipe(self):
         temp_dir = self.temp_dir.name
         temp_pngfile_pathname = os.path.join(temp_dir, "test_png.png")

@@ -697,7 +704,6 @@ def test_unbatch_datapipe(self):
         for i in unbatch_dp:
             print(i)

-
     def test_bucket_batch_datapipe(self):
         input_dp = IDP(range(20))
         with self.assertRaises(AssertionError):

@@ -787,7 +793,8 @@ def _filter_fn(data, val):
         for data, exp in zip(filter_dp, expected_dp1):
             self.assertEqual(data, exp)

-        filter_dp = input_ds.filter(nesting_level=-1, drop_empty_batches=False, filter_fn=_filter_fn, fn_kwargs={'val': 5})
+        filter_dp = input_ds.filter(nesting_level=-1, drop_empty_batches=False,
+                                    filter_fn=_filter_fn, fn_kwargs={'val': 5})
         expected_dp2: List[List[int]] = [[], [5, 6, 7, 8, 9]]
         self.assertEqual(len(list(filter_dp)), len(expected_dp2))
         for data, exp in zip(filter_dp, expected_dp2):

@@ -826,7 +833,6 @@ def _filter_fn(data, val):
         for data2, exp2 in zip(filter_dp, expected_dp6):
             self.assertEqual(data2, exp2)

-
     def test_sampler_datapipe(self):
         input_dp = IDP(range(10))
         # Default SequentialSampler

@@ -1153,6 +1159,7 @@ def __iter__(self) -> Iterator[T_co]:

         class DP3(IterDataPipe[Tuple[T_co, str]]):
             r""" DataPipe without fixed type with __init__ function"""
+
             def __init__(self, datasource):
                 self.datasource = datasource

@@ -1168,6 +1175,7 @@ def __iter__(self) -> Iterator[Tuple[T_co, str]]:

         class DP4(IterDataPipe[tuple]):
             r""" DataPipe without __iter__ annotation"""
+
             def __iter__(self):
                 raise NotImplementedError

@@ -1177,6 +1185,7 @@ def __iter__(self):

         class DP5(IterDataPipe):
             r""" DataPipe without type annotation"""
+
             def __iter__(self) -> Iterator[str]:
                 raise NotImplementedError

@@ -1187,6 +1196,7 @@ def __iter__(self) -> Iterator[str]:

         class DP6(IterDataPipe[int]):
             r""" DataPipe with plain Iterator"""
+
             def __iter__(self) -> Iterator:
                 raise NotImplementedError

@@ -1206,7 +1216,6 @@ class DP8(DP7[str]):
         self.assertTrue(issubclass(DP8, IterDataPipe))
         self.assertTrue(DP8.type.param == Awaitable[str])

-
     def test_construct_time(self):
         class DP0(IterDataPipe[Tuple]):
             @argument_validation

@@ -1269,11 +1278,9 @@ def __iter__(self) -> Iterator[Tuple[int, T_co]]:
         with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"):
             list(dp)

-
     def test_reinforce(self):
         T = TypeVar('T', int, str)

-
         class DP(IterDataPipe[T]):
             def __init__(self, ds):
                 self.ds = ds

@@ -1306,6 +1313,7 @@ def __iter__(self) -> Iterator[T]:
         with runtime_validation_disabled():
             self.assertEqual(list(d for d in dp), ds)

+
 class NumbersDataset(IterDataPipe):
     def __init__(self, size=10):
         self.size = size

@@ -1321,7 +1329,7 @@ def test_simple_traverse(self):
         numbers_dp = NumbersDataset(size=50)
         mapped_dp = numbers_dp.map(lambda x: x * 10)
         graph = torch.utils.data.graph.traverse(mapped_dp)
-        expected : Dict[Any, Any] = {mapped_dp: {numbers_dp: {}}}
+        expected: Dict[Any, Any] = {mapped_dp: {numbers_dp: {}}}
         self.assertEqual(expected, graph)

     # TODO(VitalyFedyunin): This test is incorrect because of 'buffer' nature

@@ -1377,5 +1385,6 @@ def test_old_dataloader(self):

         self.assertEqual(sorted(expected), sorted(items))

+
 if __name__ == '__main__':
     run_tests()

torch/utils/data/__init__.py (+1)

@@ -13,6 +13,7 @@
     ConcatDataset,
     Dataset,
     Dataset as MapDataPipe,
+    DataChunk,
     IterableDataset,
     IterableDataset as IterDataPipe,
     Subset,

torch/utils/data/datapipes/iter/callable.py (+11 -7)

@@ -1,6 +1,6 @@
 import warnings
 import torch.nn as nn
-from torch.utils.data import IterDataPipe, _utils, functional_datapipe
+from torch.utils.data import IterDataPipe, _utils, functional_datapipe, DataChunk
 from typing import Callable, Dict, Iterator, Optional, Sized, Tuple, TypeVar

 try:
@@ -68,14 +68,17 @@ def _apply(self, data, nesting_level):
         if nesting_level == 0:
             return self.fn(data, *self.args, **self.kwargs)
         elif nesting_level > 0:
-            if not isinstance(data, list):
+            if isinstance(data, DataChunk):
+                return type(data)([self._apply(i, nesting_level - 1) for i in data.raw_iterator()])
+            elif isinstance(data, list):
+                return [self._apply(i, nesting_level - 1) for i in data]
+            else:
                 raise IndexError(f"nesting_level {self.nesting_level} out of range (exceeds data pipe depth)")
-            result = [self._apply(i, nesting_level - 1) for i in data]
-            return result
         else:
-            if isinstance(data, list):
-                result = [self._apply(i, nesting_level) for i in data]
-                return result
+            if isinstance(data, DataChunk):
+                return type(data)([self._apply(i, nesting_level) for i in data.raw_iterator()])
+            elif isinstance(data, list):
+                return [self._apply(i, nesting_level) for i in data]
             else:
                 return self.fn(data, *self.args, **self.kwargs)
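The net effect: when map() descends into nested batches, a DataChunk comes back as a DataChunk (rebuilt through type(data)(...)), while plain lists stay plain lists. A standalone, illustration-only mirror of that dispatch, not the real class:

from torch.utils.data import DataChunk

def apply_nested(fn, data, nesting_level):
    # simplified stand-in for MapIterDataPipe._apply with nesting_level >= 0
    if nesting_level == 0:
        return fn(data)
    if isinstance(data, DataChunk):
        return type(data)([apply_nested(fn, i, nesting_level - 1) for i in data.raw_iterator()])
    if isinstance(data, list):
        return [apply_nested(fn, i, nesting_level - 1) for i in data]
    raise IndexError("nesting_level out of range (exceeds data pipe depth)")

out = apply_nested(lambda x: x * 10, DataChunk([1, 2, 3]), nesting_level=1)
print(type(out).__name__, list(out))  # DataChunk [10, 20, 30]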

@@ -162,6 +165,7 @@ class TransformsIterDataPipe(MapIterDataPipe):
         datapipe: Iterable DataPipe being transformed
         transforms: A transform or a sequence of transforms from torchvision or torchaudio.
     """
+
     def __init__(self,
                  datapipe: IterDataPipe,
                  transforms: Callable,

torch/utils/data/datapipes/iter/combinatorics.py (+1 -2)

@@ -31,8 +31,7 @@ def __init__(self,
         self.sampler_args = () if sampler_args is None else sampler_args
         self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs
         # https://github.com/python/mypy/pull/9629 will solve
-        self.sampler = sampler(data_source=self.datapipe, *self.sampler_args,
-                               **self.sampler_kwargs)  # type: ignore[misc]
+        self.sampler = sampler(data_source=self.datapipe, *self.sampler_args, **self.sampler_kwargs)  # type: ignore[misc]

     def __iter__(self) -> Iterator[T_co]:
         return iter(self.sampler)

torch/utils/data/datapipes/iter/grouping.py (+23 -15)

@@ -4,7 +4,7 @@

 from collections import defaultdict

-from torch.utils.data import IterDataPipe, functional_datapipe
+from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk
 from typing import Any, Callable, Dict, Iterator, List, Optional, Sized, Tuple, TypeVar, DefaultDict

 T_co = TypeVar('T_co', covariant=True)

@@ -31,7 +31,7 @@ def __iter__(self):


 @functional_datapipe('batch')
-class BatchIterDataPipe(IterDataPipe[List[T_co]]):
+class BatchIterDataPipe(IterDataPipe[DataChunk[T_co]]):
     r""" :class:`BatchIterDataPipe`.

     Iterable DataPipe to create mini-batches of data. An outer dimension will be added as
@@ -65,17 +65,18 @@ def __init__(self,
         self.batch_size = batch_size
         self.drop_last = drop_last
         self.length = None
+        self.wrapper_class = DataChunk

-    def __iter__(self) -> Iterator[List[T_co]]:
+    def __iter__(self) -> Iterator[DataChunk[T_co]]:
         batch: List[T_co] = []
         for x in self.datapipe:
             batch.append(x)
             if len(batch) == self.batch_size:
-                yield batch
+                yield self.wrapper_class(batch)
                 batch = []
         if len(batch) > 0:
             if not self.drop_last:
-                yield batch
+                yield self.wrapper_class(batch)
                 batch = []

     def __len__(self) -> int:
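The user-visible effect of the two changes above: batch() now yields DataChunk batches instead of plain lists. A rough sketch (RangePipe is a hypothetical stand-in for a source pipe, playing the role IDP plays in test_datapipe.py; drop_last is assumed to default to False, as the `if not self.drop_last` branch suggests):

from torch.utils.data import IterDataPipe

class RangePipe(IterDataPipe):
    # hypothetical source pipe, for illustration only
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

for batch in RangePipe(5).batch(batch_size=2):
    print(type(batch).__name__, list(batch))
# DataChunk [0, 1]
# DataChunk [2, 3]
# DataChunk [4]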
@@ -115,7 +116,7 @@ def _dive(self, element, unbatch_level):
         if unbatch_level < -1:
             raise ValueError("unbatch_level must be -1 or >= 0")
         if unbatch_level == -1:
-            if isinstance(element, list):
+            if isinstance(element, list) or isinstance(element, DataChunk):
                 for item in element:
                     for i in self._dive(item, unbatch_level=-1):
                         yield i
@@ -124,11 +125,12 @@
         elif unbatch_level == 0:
             yield element
         else:
-            if not isinstance(element, list):
+            if isinstance(element, list) or isinstance(element, DataChunk):
+                for item in element:
+                    for i in self._dive(item, unbatch_level=unbatch_level - 1):
+                        yield i
+            else:
                 raise IndexError(f"unbatch_level {self.unbatch_level} exceeds the depth of the DataPipe")
-            for item in element:
-                for i in self._dive(item, unbatch_level=unbatch_level - 1):
-                    yield i


 @functional_datapipe('bucket_batch')
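With that, unbatch() descends through DataChunk batches exactly as it does through lists. A standalone sketch of the unbatch_level=-1 recursion above:

from torch.utils.data import DataChunk

def flatten_all(element):
    # illustration-only mirror of _dive(..., unbatch_level=-1); DataChunk
    # subclasses List, so isinstance(element, list) alone would already
    # match it, but the explicit check mirrors the code above
    if isinstance(element, (list, DataChunk)):
        for item in element:
            yield from flatten_all(item)
    else:
        yield element

print(list(flatten_all(DataChunk([[0, 1], DataChunk([2, 3])]))))  # [0, 1, 2, 3]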
@@ -175,11 +177,16 @@
     def __iter__(self) -> Iterator[List[T_co]]:
         # Bucket without sorting remains same order, directly returns BatchDataset
         if self.sort_key is None:
-            yield from BatchIterDataPipe(self.datapipe, batch_size=self.batch_size, drop_last=self.drop_last)
+            for element in BatchIterDataPipe(self.datapipe, batch_size=self.batch_size, drop_last=self.drop_last):
+                if isinstance(element, DataChunk):
+                    yield list(element.raw_iterator())
+                else:
+                    yield element
         else:
             bucket: List[T_co]
             batch: List[T_co] = []
-            for bucket in self.bucket_ds:
+            for bucket_or_chunk in self.bucket_ds:
+                bucket = list(bucket_or_chunk)
                 # In-place sort within bucket
                 bucket.sort(key=self.sort_key)
                 for start in range(0, len(bucket), self.batch_size):
@@ -255,6 +262,7 @@ def __init__(self,
         assert guaranteed_group_size > 0 and group_size is not None and guaranteed_group_size <= group_size
         self.guaranteed_group_size = guaranteed_group_size
         self.drop_remaining = drop_remaining
+        self.wrapper_class = DataChunk

     def _remove_biggest_key(self, buffer_elements, buffer_size):
         biggest_key = None
@@ -283,22 +291,22 @@ def __iter__(self):
             key = self.group_key_fn(x)

             if self.group_size is not None and self.group_size == len(buffer_elements[key]):
-                yield buffer_elements[key]
+                yield self.wrapper_class(buffer_elements[key])
                 buffer_size -= len(buffer_elements[key])
                 del buffer_elements[key]

             if buffer_size == self.buffer_size:
                 (result_to_yield, buffer_size) = self._remove_biggest_key(buffer_elements, buffer_size)
                 if result_to_yield is not None:
-                    yield result_to_yield
+                    yield self.wrapper_class(result_to_yield)

             buffer_elements[key].append(x)
             buffer_size += 1

         while buffer_size:
             (result_to_yield, buffer_size) = self._remove_biggest_key(buffer_elements, buffer_size)
             if result_to_yield is not None:
-                yield result_to_yield
+                yield self.wrapper_class(result_to_yield)


 @functional_datapipe('group_by_key')

torch/utils/data/datapipes/iter/selecting.py (+17 -6)

@@ -1,4 +1,4 @@
-from torch.utils.data import IterDataPipe, functional_datapipe
+from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk
 from typing import Callable, TypeVar, Iterator, Optional, Tuple, Dict

 from .callable import MapIterDataPipe
@@ -45,12 +45,20 @@ def _applyFilter(self, data, nesting_level):
         if nesting_level == 0:
             return self._returnIfTrue(data)
         elif nesting_level > 0:
-            if not isinstance(data, list):
+            if isinstance(data, DataChunk):
+                result = filter(self._isNonEmpty, [self._applyFilter(i, nesting_level - 1)
+                                                   for i in data.raw_iterator()])
+                return type(data)(list(result))
+            elif isinstance(data, list):
+                result = filter(self._isNonEmpty, [self._applyFilter(i, nesting_level - 1) for i in data])
+                return list(result)
+            else:
                 raise IndexError(f"nesting_level {self.nesting_level} out of range (exceeds data pipe depth)")
-            result = filter(self._isNonEmpty, [self._applyFilter(i, nesting_level - 1) for i in data])
-            return list(result)
         else:  # Handling nesting_level == -1
-            if isinstance(data, list):
+            if isinstance(data, DataChunk):
+                result = filter(self._isNonEmpty, [self._applyFilter(i, nesting_level) for i in data.raw_iterator()])
+                return type(data)(list(result))
+            elif isinstance(data, list):
                 result = filter(self._isNonEmpty, [self._applyFilter(i, nesting_level) for i in data])
                 return list(result)
             else:
@@ -64,7 +72,10 @@ def _returnIfTrue(self, data):
             return data

     def _isNonEmpty(self, data):
-        return data is not None and not (data == [] and self.drop_empty_batches)
+        r = data is not None and \
+            not (isinstance(data, list) and len(data) == 0 and self.drop_empty_batches)
+        return r
+

     def __len__(self):
         raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
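The _isNonEmpty rewrite is more than style: as DataChunk is defined in this commit, the items live in self.items and the underlying list storage stays empty, so `chunk == []` compares that empty storage and is True even for a non-empty chunk. Testing isinstance(data, list) together with len(data) (which DataChunk overrides) is reliable for both containers. An illustration-only mirror:

from torch.utils.data import DataChunk

def is_non_empty(data, drop_empty_batches=True):
    # stand-in for FilterIterDataPipe._isNonEmpty as rewritten above
    return data is not None and \
        not (isinstance(data, list) and len(data) == 0 and drop_empty_batches)

chunk = DataChunk([1, 2])
print(chunk == [])                  # True -- list equality sees the empty storage
print(is_non_empty(chunk))          # True -- len(chunk) == 2, chunk is kept
print(is_non_empty(DataChunk([])))  # False -- empty chunks get dropped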

torch/utils/data/dataset.py (+26)

@@ -25,6 +25,32 @@
 T = TypeVar('T')


+class DataChunk(List[T]):
+    def __init__(self, items):
+        self.items = items
+
+    def __getitem__(self, key):
+        return self.items[key]
+
+    def __len__(self):
+        return len(self.items)
+
+    def as_str(self, indent=''):
+        res = indent + "[" + ", ".join([str(i) for i in iter(self)]) + "]"
+        return res
+
+    def __str__(self):
+        return self.as_str()
+
+    def __iter__(self) -> Iterator[T]:
+        for i in self.items:
+            yield i
+
+    def raw_iterator(self):
+        for i in self.items:
+            yield i
+
+
 class Dataset(Generic[T_co]):
     r"""An abstract class representing a :class:`Dataset`.
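For orientation, a minimal sketch of how this container behaves, using only what the class above defines (the DataChunk import works because of the torch/utils/data/__init__.py change in this commit):

from torch.utils.data import DataChunk

chunk = DataChunk([0, 1, 2, 3])
len(chunk)                  # 4, via the overridden __len__
chunk[2]                    # 2, via the overridden __getitem__ (items live in self.items)
str(chunk)                  # '[0, 1, 2, 3]' -- same text as str([0, 1, 2, 3]), per as_str()
list(chunk.raw_iterator())  # [0, 1, 2, 3]; raw_iterator() currently mirrors __iter__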
