diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py
index 8fc4cb16a..74a81a449 100644
--- a/test/test_iterdatapipe.py
+++ b/test/test_iterdatapipe.py
@@ -936,6 +936,5 @@ def test_zip_longest_iterdatapipe(self):
         # __len__ Test: length matches the length of the shortest input
         self.assertEqual(len(output_dp), 10)
 
-
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/test_local_io.py b/test/test_local_io.py
index f5a0cda61..4163e441e 100644
--- a/test/test_local_io.py
+++ b/test/test_local_io.py
@@ -26,6 +26,7 @@
     CSVDictParser,
     CSVParser,
     Decompressor,
+    FileCache,
     FileLister,
     FileOpener,
     HashChecker,
@@ -779,5 +780,22 @@ def decode(item):
         assert items[9][".bin"] == "bin9"
 
+    def test_filecache(self) -> None:
+        nfiles = 100
+        testdata = b"hello, world"
+        dest = os.path.join(self.temp_dir.name, "testdata")
+        with open(dest, "wb") as stream:
+            stream.write(testdata)
+        stage1 = IterableWrapper([dest] * nfiles)
+        stage2 = FileOpener(stage1, mode="b")
+        stage3 = FileCache(stage2, cachedir=os.path.join(self.temp_dir.name, "_cache"))
+        count = 0
+        for path, stream in stage3:
+            data = stream.read()
+            count += 1
+            assert data == testdata
+        assert count == nfiles
+
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py
index a4109cced..545b6bf2a 100644
--- a/torchdata/datapipes/iter/__init__.py
+++ b/torchdata/datapipes/iter/__init__.py
@@ -109,7 +109,12 @@
     TFRecordLoaderIterDataPipe as TFRecordLoader,
 )
 from torchdata.datapipes.iter.util.unzipper import UnZipperIterDataPipe as UnZipper
-from torchdata.datapipes.iter.util.webdataset import WebDatasetIterDataPipe as WebDataset
+from torchdata.datapipes.iter.util.webdataset import (
+    WebDatasetIterDataPipe as WebDataset,
+)
+from torchdata.datapipes.iter.util.filecache import (
+    FileCacheIterDataPipe as FileCache,
+)
 from torchdata.datapipes.iter.util.xzfileloader import (
     XzFileLoaderIterDataPipe as XzFileLoader,
     XzFileReaderIterDataPipe as XzFileReader,
@@ -140,6 +145,7 @@
     "FSSpecFileLister",
     "FSSpecFileOpener",
     "FSSpecSaver",
+    "FileCache",
     "FileLister",
     "FileOpener",
     "Filter",
diff --git a/torchdata/datapipes/iter/util/filecache.py b/torchdata/datapipes/iter/util/filecache.py
new file mode 100644
index 000000000..1c52f3b9f
--- /dev/null
+++ b/torchdata/datapipes/iter/util/filecache.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import re
+import shutil
+import sys
+import urllib.parse
+from typing import Any, Iterator, Tuple
+
+from torch.utils.data.datapipes.utils.common import StreamWrapper
+
+from torchdata.datapipes import functional_datapipe
+from torchdata.datapipes.iter import IterDataPipe
+
+
+def cache_by_fname(s):
+    # Use the last path component of the URL, percent-encoded into a safe file name.
+    result = re.sub("^.*/", "", s)
+    return urllib.parse.quote(result)
+
+
+if os.name == "nt":
+    default_cachedir = "datacache"
+else:
+    default_cachedir = "_datacache"
+
+
+@functional_datapipe("filecache")
+class FileCacheIterDataPipe(IterDataPipe[Tuple[str, Any]]):
+    r"""
+    Caches the streams produced by the source DataPipe in local files and
+    yields tuples of the original URL and a stream over the cached file
+    (functional name: ``filecache``).
+
+    Args:
+        source_datapipe: DataPipe yielding tuples of ``(url, stream)``
+        cachedir: directory in which cached files are stored
+        cachename: function mapping a URL to a cache file name
+        chunksize: buffer size used when copying a stream into the cache
+        verbose: print progress messages to stderr
+        makedir: create ``cachedir`` if it does not exist
+    """
+
+    def __init__(
+        self,
+        source_datapipe: IterDataPipe[Tuple[str, Any]],
+        cachedir=default_cachedir,
+        cachename=cache_by_fname,
+        chunksize=1024**2,
+        verbose=False,
+        makedir=True,
+    ) -> None:
+        super().__init__()
+        if not os.path.exists(cachedir):
+            if makedir:
+                os.makedirs(cachedir)
+            else:
+                raise ValueError(f"Cache directory {cachedir} does not exist.")
+        self.source_datapipe: IterDataPipe[Tuple[str, Any]] = source_datapipe
+        self.cachedir = cachedir
+        self.cachename = cachename
+        self.verbose = verbose
+        self.chunksize = chunksize
+
+    def __iter__(self) -> Iterator[Tuple[str, Any]]:
+        for url, stream in self.source_datapipe:
+            cached = os.path.join(self.cachedir, self.cachename(url))
+            if not os.path.exists(cached):
+                if self.verbose:
+                    print(f"# downloading {url} to {cached}", file=sys.stderr)
+                # Write to a temporary file and rename it afterwards, so that a
+                # partially written file is never visible under its final cache name.
+                with open(cached + ".temp", "wb") as dest:
+                    shutil.copyfileobj(stream, dest, self.chunksize)
+                os.rename(cached + ".temp", cached)
+            if self.verbose:
+                print(f"# returning {cached}", file=sys.stderr)
+            cached_stream = open(cached, "rb")
+            yield url, StreamWrapper(cached_stream)
+
+    def __len__(self) -> int:
+        return len(self.source_datapipe)
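
A minimal usage sketch for the new datapipe (not part of the patch; the URL and cache directory are placeholders). FileCache consumes (url, stream) tuples, so it can sit directly behind a source such as HttpReader, and the @functional_datapipe("filecache") registration also makes it available in functional form:

    from torchdata.datapipes.iter import HttpReader, IterableWrapper

    # Placeholder URL; any datapipe yielding (url, stream) tuples works upstream.
    urls = ["https://example.com/shard-000000.tar"]
    dp = HttpReader(IterableWrapper(urls))
    # The first pass downloads each file into "./_cache"; subsequent passes
    # (e.g. later epochs) read the locally cached copy instead of the network.
    dp = dp.filecache(cachedir="./_cache", verbose=True)
    for url, stream in dp:
        data = stream.read()

Because each download is written to a ".temp" file and only renamed into place once the copy finishes, a cache entry is either absent or complete; re-running a pipeline after an interrupted download simply re-fetches the missing files.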