Skip to content
35 changes: 31 additions & 4 deletions tensorflow_datasets/core/download/checksums.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,15 @@ class MyDataset(tfds.core.DatasetBuilder):


def _list_dir(path: str) -> List[str]:
  """Returns a list of entries contained within the given directory.

  Args:
    path: Path to the directory.

  Returns:
    List of entries contained within the given directory.
  """
  # Delegates to tf.io.gfile so GCS/HDFS paths work the same as local ones.
  return tf.io.gfile.listdir(path)


@utils.memoize()
Expand All @@ -124,7 +130,14 @@ def _checksum_paths() -> Dict[str, str]:


def _get_path(dataset_name: str) -> str:
"""Returns path to where checksums are stored for a given dataset."""
"""Returns path to where checksums are stored for a given dataset.

Args:
dataset_name: Name of the dataset.

Returns:
Path to where the checksums for the given dataset are stored.
"""
path = _checksum_paths().get(dataset_name, None)
if path:
return path
Expand All @@ -142,14 +155,28 @@ def _get_path(dataset_name: str) -> str:


def _get_url_infos(checksums_path: str) -> Dict[str, UrlInfo]:
  """Returns {URL: (size, checksum)}s stored within file at given path.

  Args:
    checksums_path: Path to the checksums file.

  Returns:
    Dict mapping the URLs to their corresponding UrlInfos.
  """
  # Read the whole file first, then parse line-by-line; parsing is shared
  # with parse_url_infos so both code paths agree on the file format.
  with tf.io.gfile.GFile(checksums_path) as f:
    content = f.read()
  return parse_url_infos(content.splitlines())


def parse_url_infos(checksums_file: Iterable[str]) -> Dict[str, UrlInfo]:
"""Returns {URL: (size, checksum)}s stored within given file."""
"""Returns {URL: (size, checksum)}s stored within given file.

Args:
checksums_file: List of checksums.

Returns:
Dict mapping URLs to their corresponding UrlInfos.
"""
url_infos = {}
for line in checksums_file:
line = line.strip() # Remove the trailing '\r' on Windows OS.
Expand Down
87 changes: 79 additions & 8 deletions tensorflow_datasets/core/download/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,15 @@ def __init__(
self._executor = concurrent.futures.ThreadPoolExecutor(1)

def __getstate__(self):
"""Remove un-pickleable attributes and return the state."""
"""Remove un-pickleable attributes and return the state.

Returns:
The state of the download manager.

Raises:
NotImplementedError: If the register_checksums flag is enabled
in a parallelized download manager.
"""
if self._register_checksums:
# Currently, checksums registration from Beam not supported.
raise NotImplementedError(
Expand All @@ -241,12 +249,14 @@ def __getstate__(self):

@property
def _downloader(self):
  """Returns the downloader object.

  Lazily constructs the downloader on first access and caches it in a
  name-mangled attribute — presumably so the manager can be pickled
  before first use (see __getstate__); TODO confirm.
  """
  if not self.__downloader:
    self.__downloader = downloader.get_downloader()
  return self.__downloader

@property
def _extractor(self):
  """Returns the extractor object.

  Lazily constructs the extractor on first access and caches it in a
  name-mangled attribute, mirroring the `_downloader` property above.
  """
  if not self.__extractor:
    self.__extractor = extractor.get_extractor()
  return self.__extractor
Expand All @@ -257,6 +267,15 @@ def downloaded_size(self):
return sum(url_info.size for url_info in self._recorded_url_infos.values())

def _get_final_dl_path(self, url, sha256):
  """Builds the path where a completed download is stored.

  Args:
    url: The download url.
    sha256: The sha256 hash hexdump.

  Returns:
    The final download path.
  """
  # The file name encodes both the url and its checksum, so different
  # versions of the same url never collide on disk.
  dl_fname = resource_lib.get_dl_fname(url, sha256)
  return os.path.join(self._download_dir, dl_fname)

Expand Down Expand Up @@ -295,7 +314,9 @@ def _handle_download_result(
dst_path: `url_path` (or `file_path` when `register_checksums=True`)

Raises:
NonMatchingChecksumError:
ValueError: If the number of files found in the tmp dir is not 1 and if
the checksum was not registered.
NonMatchingChecksumError: If the checksums do not match.
"""
# Extract the file name, path from the tmp_dir
fnames = tf.io.gfile.listdir(tmp_dir_path)
Expand Down Expand Up @@ -414,7 +435,14 @@ def _find_existing_path(self, url: str, url_path: str) -> Optional[str]:
return existing_path

def download_checksums(self, checksums_url):
"""Downloads checksum file from the given URL and adds it to registry."""
"""Downloads checksum file from the given URL and adds it to registry.

Args:
checksums_url: The checksum url.

Returns:
Updated url registry.
"""
checksums_path = self.download(checksums_url)
with tf.io.gfile.GFile(checksums_path) as f:
self._url_infos.update(checksums.parse_url_infos(f))
Expand Down Expand Up @@ -475,6 +503,7 @@ def _download(self, resource: Union[str, resource_lib.Resource]):
'%s.tmp.%s' % (resource_lib.get_dl_dirname(url), uuid.uuid4().hex))
tf.io.gfile.makedirs(download_dir_path)
logging.info('Downloading %s into %s...', url, download_dir_path)

def callback(url_info):
return self._handle_download_result(
resource=resource,
Expand All @@ -487,7 +516,14 @@ def callback(url_info):
@utils.build_synchronize_decorator()
@utils.memoize()
def _extract(self, resource):
"""Extract a single archive, returns Promise->path to extraction result."""
"""Extract a single archive, returns Promise->path to extraction result.

Args:
resource: The path to the file to extract.

Return:
The resolved promise.
"""
if isinstance(resource, six.string_types):
resource = resource_lib.Resource(path=resource)
path = resource.path
Expand All @@ -506,9 +542,17 @@ def _extract(self, resource):
@utils.build_synchronize_decorator()
@utils.memoize()
def _download_extract(self, resource):
"""Download-extract `Resource` or url, returns Promise->path."""
"""Download-extract `Resource` or url, returns Promise->path.

Args:
resource: The url to download data from.

Returns:
The resolved promise.
"""
if isinstance(resource, six.string_types):
resource = resource_lib.Resource(url=resource)

def callback(path):
resource.path = path
return self._extract(resource)
Expand Down Expand Up @@ -605,7 +649,14 @@ def download_and_extract(self, url_or_urls):

@property
def manual_dir(self):
"""Returns the directory containing the manually extracted data."""
"""Returns the directory containing the manually extracted data.

Returns:
The path to the dir containing the manually extracted data.

Raises:
AssertionError: If the Manual directory does not exist or is empty
"""
if not self._manual_dir:
raise AssertionError(
'Manual directory was enabled. '
Expand All @@ -621,7 +672,19 @@ def manual_dir(self):


def _read_url_info(url_path: str) -> checksums.UrlInfo:
"""Loads the `UrlInfo` from the `.INFO` file."""
"""Loads the `UrlInfo` from the `.INFO` file.

Args:
url_path: The path to the .INFO file.

Returns:
UrlInfo object.

Raises:
ValueError: If 'url_info' is not found in the .INFO file.
This likely indicates the files were downloaded
with a previous version of TFDS (<=3.1.0).
"""
file_info = resource_lib.read_info_file(url_path)
if 'url_info' not in file_info:
raise ValueError(
Expand All @@ -636,7 +699,15 @@ def _wait_on_promise(p):


def _map_promise(map_fn, all_inputs):
  """Map the function into each element and resolve the promise.

  Args:
    map_fn: The function to be mapped into each element.
    all_inputs: The elements the function is to be mapped into.

  Returns:
    The structure of resolved promise results, matching `all_inputs`.
  """
  all_promises = tf.nest.map_structure(map_fn, all_inputs)  # Apply the function
  # Block until every promise in the (possibly nested) structure resolves.
  res = tf.nest.map_structure(_wait_on_promise, all_promises)
  return res
36 changes: 35 additions & 1 deletion tensorflow_datasets/core/download/download_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,19 @@


def _sha256(str_):
"""Returns the SHA256 hexdump of the given string.

Args:
str_: String to compute the SHA256 hexdump of.

Returns:
SHA256 hexdump of the given string.
"""
return hashlib.sha256(str_.encode('utf8')).hexdigest()


class Artifact(object):
"""Artifact class for tracking files used for testing."""
# For testing only.

def __init__(self, name, url=None):
Expand Down Expand Up @@ -86,6 +95,15 @@ def _make_downloader_mock(self):
"""`downloader.download` patch which creates the returns the path."""

def _download(url, tmpdir_path):
"""Download function of the DownloadManager.

Args:
url: URL to download from.
tmpdir_path: Path to the temporary directory.

Returns:
The resolved promise.
"""
self.downloaded_urls.append(url) # Record downloader.download() calls
# If the name isn't explicitly provided, then it is extracted from the
# url.
Expand Down Expand Up @@ -135,6 +153,12 @@ def setUp(self):
self.addCleanup(absltest.mock.patch.stopall)

def _write_info(self, path, info):
  """Serializes `info` to JSON and writes it to the fake filesystem.

  Args:
    path: Path to the .INFO file.
    info: Content to be written.
  """
  serialized = json.dumps(info)
  self.fs.add_file(path, serialized)

Expand All @@ -146,6 +170,17 @@ def _get_manager(
extract_dir='/extract_dir',
**kwargs
):
"""Returns the DownloadManager object.

Args:
register_checksums: Whether or not to register the checksums.
url_infos: UrlInfos for the URLs.
dl_dir: Path to the download directory.
extract_dir: Path to the extraction directory.

Returns:
DownloadManager object.
"""
manager = dm.DownloadManager(
dataset_name='mnist',
download_dir=dl_dir,
Expand Down Expand Up @@ -257,7 +292,6 @@ def test_download_and_extract_archive_ext_in_fname(self):
'a': '/extract_dir/ZIP.%s' % a.file_name,
})


def test_download_and_extract_already_downloaded(self):
a = Artifact('a') # Extract can't be deduced from the url, but from .INFO
# File was already downloaded:
Expand Down
Loading