42
42
from torch .testing ._internal .common_utils import TestCase , run_tests
43
43
from torch .utils .data import (
44
44
DataLoader ,
45
+ DataChunk ,
45
46
IterDataPipe ,
46
47
MapDataPipe ,
47
48
RandomSampler ,
@@ -108,6 +109,14 @@ def create_temp_dir_and_files():
108
109
(temp_sub_dir , temp_sub_file1_name , temp_sub_file2_name )]
109
110
110
111
112
+ class TestDataChunk (TestCase ):
113
+
114
+ def test_as_string (self ):
115
+ elements = list (range (10 ))
116
+ chunk : DataChunk [int ] = DataChunk (elements )
117
+ self .assertEquals (str (chunk ), str (elements ))
118
+
119
+
111
120
class TestIterableDataPipeBasic (TestCase ):
112
121
113
122
def setUp (self ):
@@ -141,7 +150,6 @@ def test_listdirfiles_iterable_datapipe(self):
141
150
self .assertTrue ((pathname in self .temp_files ) or (pathname in self .temp_sub_files ))
142
151
self .assertEqual (count , len (self .temp_files ) + len (self .temp_sub_files ))
143
152
144
-
145
153
def test_loadfilesfromdisk_iterable_datapipe (self ):
146
154
# test import datapipe class directly
147
155
from torch .utils .data .datapipes .iter import (
@@ -216,7 +224,6 @@ def test_readfilesfromzip_iterable_datapipe(self):
216
224
self .assertEqual (data_ref [1 ].read (), f .read ())
217
225
data_ref [1 ].close ()
218
226
219
-
220
227
def test_routeddecoder_iterable_datapipe (self ):
221
228
temp_dir = self .temp_dir .name
222
229
temp_pngfile_pathname = os .path .join (temp_dir , "test_png.png" )
@@ -697,7 +704,6 @@ def test_unbatch_datapipe(self):
697
704
for i in unbatch_dp :
698
705
print (i )
699
706
700
-
701
707
def test_bucket_batch_datapipe (self ):
702
708
input_dp = IDP (range (20 ))
703
709
with self .assertRaises (AssertionError ):
@@ -787,7 +793,8 @@ def _filter_fn(data, val):
787
793
for data , exp in zip (filter_dp , expected_dp1 ):
788
794
self .assertEqual (data , exp )
789
795
790
- filter_dp = input_ds .filter (nesting_level = - 1 , drop_empty_batches = False , filter_fn = _filter_fn , fn_kwargs = {'val' : 5 })
796
+ filter_dp = input_ds .filter (nesting_level = - 1 , drop_empty_batches = False ,
797
+ filter_fn = _filter_fn , fn_kwargs = {'val' : 5 })
791
798
expected_dp2 : List [List [int ]] = [[], [5 , 6 , 7 , 8 , 9 ]]
792
799
self .assertEqual (len (list (filter_dp )), len (expected_dp2 ))
793
800
for data , exp in zip (filter_dp , expected_dp2 ):
@@ -826,7 +833,6 @@ def _filter_fn(data, val):
826
833
for data2 , exp2 in zip (filter_dp , expected_dp6 ):
827
834
self .assertEqual (data2 , exp2 )
828
835
829
-
830
836
def test_sampler_datapipe (self ):
831
837
input_dp = IDP (range (10 ))
832
838
# Default SequentialSampler
@@ -1153,6 +1159,7 @@ def __iter__(self) -> Iterator[T_co]:
1153
1159
1154
1160
class DP3 (IterDataPipe [Tuple [T_co , str ]]):
1155
1161
r""" DataPipe without fixed type with __init__ function"""
1162
+
1156
1163
def __init__ (self , datasource ):
1157
1164
self .datasource = datasource
1158
1165
@@ -1168,6 +1175,7 @@ def __iter__(self) -> Iterator[Tuple[T_co, str]]:
1168
1175
1169
1176
class DP4 (IterDataPipe [tuple ]):
1170
1177
r""" DataPipe without __iter__ annotation"""
1178
+
1171
1179
def __iter__ (self ):
1172
1180
raise NotImplementedError
1173
1181
@@ -1177,6 +1185,7 @@ def __iter__(self):
1177
1185
1178
1186
class DP5 (IterDataPipe ):
1179
1187
r""" DataPipe without type annotation"""
1188
+
1180
1189
def __iter__ (self ) -> Iterator [str ]:
1181
1190
raise NotImplementedError
1182
1191
@@ -1187,6 +1196,7 @@ def __iter__(self) -> Iterator[str]:
1187
1196
1188
1197
class DP6 (IterDataPipe [int ]):
1189
1198
r""" DataPipe with plain Iterator"""
1199
+
1190
1200
def __iter__ (self ) -> Iterator :
1191
1201
raise NotImplementedError
1192
1202
@@ -1206,7 +1216,6 @@ class DP8(DP7[str]):
1206
1216
self .assertTrue (issubclass (DP8 , IterDataPipe ))
1207
1217
self .assertTrue (DP8 .type .param == Awaitable [str ])
1208
1218
1209
-
1210
1219
def test_construct_time (self ):
1211
1220
class DP0 (IterDataPipe [Tuple ]):
1212
1221
@argument_validation
@@ -1269,11 +1278,9 @@ def __iter__(self) -> Iterator[Tuple[int, T_co]]:
1269
1278
with self .assertRaisesRegex (RuntimeError , r"Expected an instance as subtype" ):
1270
1279
list (dp )
1271
1280
1272
-
1273
1281
def test_reinforce (self ):
1274
1282
T = TypeVar ('T' , int , str )
1275
1283
1276
-
1277
1284
class DP (IterDataPipe [T ]):
1278
1285
def __init__ (self , ds ):
1279
1286
self .ds = ds
@@ -1306,6 +1313,7 @@ def __iter__(self) -> Iterator[T]:
1306
1313
with runtime_validation_disabled ():
1307
1314
self .assertEqual (list (d for d in dp ), ds )
1308
1315
1316
+
1309
1317
class NumbersDataset (IterDataPipe ):
1310
1318
def __init__ (self , size = 10 ):
1311
1319
self .size = size
@@ -1321,7 +1329,7 @@ def test_simple_traverse(self):
1321
1329
numbers_dp = NumbersDataset (size = 50 )
1322
1330
mapped_dp = numbers_dp .map (lambda x : x * 10 )
1323
1331
graph = torch .utils .data .graph .traverse (mapped_dp )
1324
- expected : Dict [Any , Any ] = {mapped_dp : {numbers_dp : {}}}
1332
+ expected : Dict [Any , Any ] = {mapped_dp : {numbers_dp : {}}}
1325
1333
self .assertEqual (expected , graph )
1326
1334
1327
1335
# TODO(VitalyFedyunin): This test is incorrect because of 'buffer' nature
@@ -1377,5 +1385,6 @@ def test_old_dataloader(self):
1377
1385
1378
1386
self .assertEqual (sorted (expected ), sorted (items ))
1379
1387
1388
+
1380
1389
if __name__ == '__main__' :
1381
1390
run_tests ()
0 commit comments