diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index 9ee73ba44..ea1fa3424 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -1056,6 +1056,18 @@ def test_itertomap_mapdatapipe(self): self.assertEqual(len(wa), 1) self.assertRegex(str(wa[0].message), r"Found duplicate key") + # More lazily: load only until necessary + source_dp = IterableWrapper(list(zip(keys, values))) + lazy_map_dp = source_dp.to_map_datapipe() + _ = lazy_map_dp["k" + str(4)] + self.assertEqual(len(lazy_map_dp._map), 5) + _ = lazy_map_dp["k" + str(7)] + self.assertEqual(len(lazy_map_dp._map), 8) + try: + _ = lazy_map_dp["k" + str(20)] + except IndexError: + self.assertEqual(len(lazy_map_dp._map), 10) + def test_mux_longest_iterdatapipe(self): # Functional Test: Elements are yielded one at a time from each DataPipe, until they are all exhausted diff --git a/torchdata/datapipes/iter/util/converter.py b/torchdata/datapipes/iter/util/converter.py index 0721e1741..59d68a64e 100644 --- a/torchdata/datapipes/iter/util/converter.py +++ b/torchdata/datapipes/iter/util/converter.py @@ -5,8 +5,7 @@ # LICENSE file in the root directory of this source tree. import warnings - -from typing import Callable, Dict, Optional +from typing import Callable, Dict, Iterator, Optional from torch.utils.data import IterDataPipe, MapDataPipe from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE @@ -37,28 +36,43 @@ class IterToMapConverterMapDataPipe(MapDataPipe): will be replaced by the new value. Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> source_dp = IterableWrapper([(i, i) for i in range(10)]) - >>> map_dp = source_dp.to_map_datapipe() - >>> list(map_dp) - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - >>> source_dp2 = IterableWrapper([('a', 1), ('b', 2), ('c', 1)]) - >>> map_dp2 = source_dp2.to_map_datapipe() - >>> map_dp2['a'] - 1 - >>> def row_to_tuple(row): - >>> label = row[0] - >>> data = row[1:] - >>> return label, data - >>> source_dp3 = IterableWrapper([('a', 1, 1, 1, 1, 1, 1), ('b', 2, 2, 2, 2, 2, 2), ('c', 3, 3, 3, 3, 3, 3)]) - >>> map_dp3 = source_dp3.to_map_datapipe(key_value_fn=row_to_tuple) - >>> map_dp3['a'] + + .. testsetup:: + + from torchdata.datapipes.iter import IterableWrapper + + .. testcode:: + + source_dp = IterableWrapper([(i, i) for i in range(10)]) + map_dp = source_dp.to_map_datapipe() + assert list(map_dp) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + + source_dp2 = IterableWrapper([('a', 1), ('b', 2), ('c', 1)]) + map_dp2 = source_dp2.to_map_datapipe() + assert map_dp2['a'] == 1 + + .. testcode:: + + def row_to_tuple(row): + label = row[0] + data = row[1:] + return label, data + source_dp3 = IterableWrapper([('a', 1, 1, 1, 1, 1, 1), ('b', 2, 2, 2, 2, 2, 2), ('c', 3, 3, 3, 3, 3, 3)]) + map_dp3 = source_dp3.to_map_datapipe(key_value_fn=row_to_tuple) + print(map_dp3['a']) + + .. testoutput:: + (1, 1, 1, 1, 1, 1) + """ + datapipe: IterDataPipe key_value_fn: Optional[Callable] _map: Optional[Dict] _length: int + _itr: Optional[Iterator] + _depleted: bool def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = None): if not isinstance(datapipe, IterDataPipe): @@ -68,32 +82,53 @@ def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = No _check_unpickable_fn(key_value_fn) self.key_value_fn = key_value_fn # type: ignore[assignment] self._map = None + self._itr = None + self._depleted = False def _load_map(self): - self._map = {} - for d in self.datapipe: - inp = d if self.key_value_fn is None else self.key_value_fn(d) + while not self._depleted: try: - length = len(inp) - except TypeError: - raise TypeError(f"Cannot convert dictionary update element {type(inp)} ({inp}) to a sequence") - if length != 2: - raise ValueError(f"dictionary update sequence element has length {length}, 2 is required") - key, value = inp - if key in self._map: - warnings.warn(f"Found duplicate key {key}. Please check your `key_value_fn`") - self._map[key] = value + self._load_next_item() + except StopIteration: + self._depleted = True + self._itr = None def __getitem__(self, index): try: - if self._map is None: - self._load_map() - return self._map[index] # type: ignore[index] + if self._map is not None: + return self._map[index] except KeyError: - raise IndexError(f"Index {index} is invalid for IterToMapConverter.") + pass + while not self._depleted: + try: + key, value = self._load_next_item() + if key == index: + return value + except StopIteration: + self._depleted = True + self._itr = None + raise IndexError(f"Index {index} is invalid for IterToMapConverter.") + + def _load_next_item(self): + if self._map is None: + self._map = {} + self._itr = iter(self.datapipe) + elem = next(self._itr) # type: ignore[arg-type] + inp = elem if self.key_value_fn is None else self.key_value_fn(elem) + try: + length = len(inp) + except TypeError: + raise TypeError(f"Cannot convert dictionary update element {type(inp)} ({inp}) to a sequence") + if length != 2: + raise ValueError(f"dictionary update sequence element has length {length}, 2 is required") + key, value = inp + if key in self._map: + warnings.warn(f"Found duplicate key {key}. Please check your `key_value_fn`") + self._map[key] = value + return key, value def __len__(self): - if self._map is not None: + if self._depleted: return len(self._map) # type: ignore[arg-type] try: return len(self.datapipe) @@ -112,14 +147,10 @@ def __getstate__(self): dill_key_value_fn = dill.dumps(self.key_value_fn) else: dill_key_value_fn = self.key_value_fn - return ( - self.datapipe, - dill_key_value_fn, - self._map, - ) + return (self.datapipe, dill_key_value_fn, self._map, self._depleted) def __setstate__(self, state): - (self.datapipe, dill_key_value_fn, self._map) = state + (self.datapipe, dill_key_value_fn, self._map, self._depleted) = state if DILL_AVAILABLE: self.key_value_fn = dill.loads(dill_key_value_fn) # type: ignore[assignment] else: