     ChainDataset,
     ConcatDataset,
     DataLoader,
-    DataLoader2,
     Dataset,
     IterableDataset,
     IterDataPipe,
     Subset,
     TensorDataset,
-    communication,
     _utils
 )
 from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
 from torch.utils.data.dataset import random_split
 from torch.utils.data.datapipes.iter import IterableWrapper
-from torch.utils.data.datapipes.map import SequenceWrapper
 from torch._utils import ExceptionWrapper
 from torch.testing._internal.common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS,
                                                    IS_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, slowTest,
@@ -2222,114 +2219,6 @@ def test_excessive_thread_creation_warning(self):
             r"excessive worker creation might get DataLoader running slow or even freeze"):
             dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000)
 
-# Define a global function for testing purposes since local functions cannot be pickled
-def identity(x):
-    return x
-
-@unittest.skipIf(
-    TEST_WITH_TSAN,
-    "Fails with TSAN with the following error: starting new threads after multi-threaded "
-    "fork is not supported. Dying (set die_after_fork=0 to override)")
-class TestDataLoader2(TestCase):
-    @skipIfNoDill
-    def test_basics(self):
-        # TODO(VitalyFedyunin): This test will start breaking if we remove guaranteed order
-        # of traversing workers
-        dp = IterableWrapper(list(range(1000))).sharding_filter()
-        dl = DataLoader(dp, batch_size=3, collate_fn=identity, num_workers=2)
-        dl2 = DataLoader2(dp, batch_size=3, collate_fn=identity, num_workers=2)
-        dl2_threading = DataLoader2(dp, batch_size=3, collate_fn=identity, num_workers=2, parallelism_mode='thread')
-        self.assertEqual(list(dl), list(dl2))
-        self.assertEqual(list(dl), list(dl2_threading))
-
-    class Sorter(IterDataPipe):
-        def __init__(self, datapipe):
-            self.datapipe = datapipe
-
-        def __iter__(self):
-            return iter(sorted(self.datapipe))
-
-    def test_shuffle(self):
-        items = list(range(1000))
-        dp = IterableWrapper(items).sharding_filter().shuffle()
-
-        dl = DataLoader2(dp, batch_size=None, num_workers=2, shuffle=False)
-        self.assertEqual(items, list(dl))
-
-        dl = DataLoader2(dp, batch_size=None, num_workers=2, shuffle=True)
-        self.assertNotEqual(items, list(dl))
-        self.assertEqual(items, sorted(list(dl)))
-
-        dl = DataLoader2(dp, batch_size=None, num_workers=2, shuffle=True)
-        self.assertNotEqual(items, list(dl))
-        self.assertEqual(items, sorted(list(dl)))
-
-        dl = DataLoader2(self.Sorter(dp), batch_size=None, num_workers=2, shuffle=True)
-        self.assertEqual(list(dl), items)
-
-        dl = DataLoader2(self.Sorter(dp), batch_size=None, num_workers=2, shuffle=True)
-        self.assertEqual(list(dl), items)
-
-
-@unittest.skipIf(
-    TEST_WITH_TSAN,
-    "Fails with TSAN with the following error: starting new threads after multi-threaded "
-    "fork is not supported. Dying (set die_after_fork=0 to override)")
-class TestDataLoader2_EventLoop(TestCase):
-    @skipIfNoDill
-    def test_basic_threading(self):
-        def clean_me(process, req_queue, res_queue):
-            req_queue.put(communication.messages.TerminateRequest())
-            _ = res_queue.get()
-            process.join()
-
-        it = list(range(100))
-        numbers_dp = IterableWrapper(it)
-        (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.SpawnThreadForDataPipeline(numbers_dp)
-
-        process.start()
-        local_datapipe = communication.iter.QueueWrapper(
-            communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue))
-
-        actual = list(local_datapipe)
-        clean_me(process, req_queue, res_queue)
-
-        self.assertEqual(list(range(100)), actual)
-
-    @skipIfNoDill
-    def test_basic_mapdatapipe_threading(self):
-        def clean_me(process, req_queue, res_queue):
-            req_queue.put(communication.messages.TerminateRequest())
-            _ = res_queue.get()
-            process.join()
-
-        input_len = 100
-        it = list(range(input_len))
-        numbers_dp = SequenceWrapper(it)
-        (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.SpawnThreadForDataPipeline(
-            numbers_dp)
-
-        process.start()
-
-        # Functional Test: Ensure that you can retrieve every element from the Queue and DataPipe
-        local_datapipe = communication.map.QueueWrapperForMap(
-            communication.protocol.MapDataPipeQueueProtocolClient(req_queue, res_queue))
-        actual = list(local_datapipe)
-        self.assertEqual([(x, x) for x in range(100)], actual)
-
-        # Functional Test: raise Error when input
-        local_datapipe = communication.map.QueueWrapperForMap(
-            communication.protocol.MapDataPipeQueueProtocolClient(req_queue, res_queue))
-        with self.assertRaisesRegex(IndexError, "out of bound"):
-            local_datapipe[1000]
-
-        # __len__ Test: Ensure that the correct length is returned
-        local_datapipe = communication.map.QueueWrapperForMap(
-            communication.protocol.MapDataPipeQueueProtocolClient(req_queue, res_queue))
-        self.assertEqual(input_len, len(local_datapipe))
-
-        clean_me(process, req_queue, res_queue)
-
 
 class IntegrationTestDataLoaderDataPipe(TestCase):
     r"""