diff --git a/tensorflow_datasets/core/download/checksums.py b/tensorflow_datasets/core/download/checksums.py
index dd827cdf3ec..35b01328498 100644
--- a/tensorflow_datasets/core/download/checksums.py
+++ b/tensorflow_datasets/core/download/checksums.py
@@ -104,9 +104,15 @@ class MyDataset(tfds.core.DatasetBuilder):
 
 
 def _list_dir(path: str) -> List[str]:
-  return tf.io.gfile.listdir(path)
+  """Returns a list of entries contained within the given directory.
+  Args:
+    path: Path to the directory.
+  Returns:
+    List of entries contained within the given directory.
+  """
+  return tf.io.gfile.listdir(path)
 
 
 @utils.memoize()
@@ -124,7 +130,14 @@ def _checksum_paths() -> Dict[str, str]:
 
 
 def _get_path(dataset_name: str) -> str:
-  """Returns path to where checksums are stored for a given dataset."""
+  """Returns path to where checksums are stored for a given dataset.
+
+  Args:
+    dataset_name: Name of the dataset.
+
+  Returns:
+    Path to where the checksums for the given dataset are stored.
+  """
   path = _checksum_paths().get(dataset_name, None)
   if path:
     return path
@@ -142,14 +155,28 @@ def _get_path(dataset_name: str) -> str:
 
 
 def _get_url_infos(checksums_path: str) -> Dict[str, UrlInfo]:
-  """Returns {URL: (size, checksum)}s stored within file at given path."""
+  """Returns {URL: (size, checksum)}s stored within file at given path.
+
+  Args:
+    checksums_path: Path to the checksums file.
+
+  Returns:
+    Dict mapping the URLs to their corresponding UrlInfos.
+  """
   with tf.io.gfile.GFile(checksums_path) as f:
     content = f.read()
   return parse_url_infos(content.splitlines())
 
 
 def parse_url_infos(checksums_file: Iterable[str]) -> Dict[str, UrlInfo]:
-  """Returns {URL: (size, checksum)}s stored within given file."""
+  """Returns {URL: (size, checksum)}s stored within given file.
+
+  Args:
+    checksums_file: Lines of the checksums file.
+
+  Returns:
+    Dict mapping URLs to their corresponding UrlInfos.
+  """
   url_infos = {}
   for line in checksums_file:
     line = line.strip()  # Remove the trailing '\r' on Windows OS.
diff --git a/tensorflow_datasets/core/download/download_manager.py b/tensorflow_datasets/core/download/download_manager.py
index a0438c37042..adfa04e1cbe 100644
--- a/tensorflow_datasets/core/download/download_manager.py
+++ b/tensorflow_datasets/core/download/download_manager.py
@@ -227,7 +227,15 @@ def __init__(
     self._executor = concurrent.futures.ThreadPoolExecutor(1)
 
   def __getstate__(self):
-    """Remove un-pickleable attributes and return the state."""
+    """Remove un-pickleable attributes and return the state.
+
+    Returns:
+      The state of the download manager.
+
+    Raises:
+      NotImplementedError: If the register_checksums flag is enabled
+        in a parallelized download manager.
+    """
     if self._register_checksums:
       # Currently, checksums registration from Beam not supported.
       raise NotImplementedError(
@@ -241,12 +249,14 @@ def __getstate__(self):
 
   @property
   def _downloader(self):
+    """Returns the downloader object."""
     if not self.__downloader:
       self.__downloader = downloader.get_downloader()
     return self.__downloader
 
   @property
   def _extractor(self):
+    """Returns the extractor object."""
     if not self.__extractor:
       self.__extractor = extractor.get_extractor()
     return self.__extractor
@@ -257,6 +267,15 @@ def downloaded_size(self):
     return sum(url_info.size for url_info in self._recorded_url_infos.values())
 
   def _get_final_dl_path(self, url, sha256):
+    """Returns the final download path.
+
+    Args:
+      url: The download url.
+      sha256: The sha256 hash hexdigest.
+
+    Returns:
+      The final download path.
+ """ return os.path.join(self._download_dir, resource_lib.get_dl_fname(url, sha256)) @@ -295,7 +314,9 @@ def _handle_download_result( dst_path: `url_path` (or `file_path` when `register_checksums=True`) Raises: - NonMatchingChecksumError: + ValueError: If the number of files found in the tmp dir is not 1 and if + the checksum was not registered. + NonMatchingChecksumError: If the checksums do not match. """ # Extract the file name, path from the tmp_dir fnames = tf.io.gfile.listdir(tmp_dir_path) @@ -414,7 +435,14 @@ def _find_existing_path(self, url: str, url_path: str) -> Optional[str]: return existing_path def download_checksums(self, checksums_url): - """Downloads checksum file from the given URL and adds it to registry.""" + """Downloads checksum file from the given URL and adds it to registry. + + Args: + checksums_url: The checksum url. + + Returns: + Updated url registry. + """ checksums_path = self.download(checksums_url) with tf.io.gfile.GFile(checksums_path) as f: self._url_infos.update(checksums.parse_url_infos(f)) @@ -475,6 +503,7 @@ def _download(self, resource: Union[str, resource_lib.Resource]): '%s.tmp.%s' % (resource_lib.get_dl_dirname(url), uuid.uuid4().hex)) tf.io.gfile.makedirs(download_dir_path) logging.info('Downloading %s into %s...', url, download_dir_path) + def callback(url_info): return self._handle_download_result( resource=resource, @@ -487,7 +516,14 @@ def callback(url_info): @utils.build_synchronize_decorator() @utils.memoize() def _extract(self, resource): - """Extract a single archive, returns Promise->path to extraction result.""" + """Extract a single archive, returns Promise->path to extraction result. + + Args: + resource: The path to the file to extract. + + Return: + The resolved promise. + """ if isinstance(resource, six.string_types): resource = resource_lib.Resource(path=resource) path = resource.path @@ -506,9 +542,17 @@ def _extract(self, resource): @utils.build_synchronize_decorator() @utils.memoize() def _download_extract(self, resource): - """Download-extract `Resource` or url, returns Promise->path.""" + """Download-extract `Resource` or url, returns Promise->path. + + Args: + resource: The url to download data from. + + Returns: + The resolved promise. + """ if isinstance(resource, six.string_types): resource = resource_lib.Resource(url=resource) + def callback(path): resource.path = path return self._extract(resource) @@ -605,7 +649,14 @@ def download_and_extract(self, url_or_urls): @property def manual_dir(self): - """Returns the directory containing the manually extracted data.""" + """Returns the directory containing the manually extracted data. + + Returns: + The path to the dir containing the manually extracted data. + + Raises: + AssertionError: If the Manual directory does not exist or is empty + """ if not self._manual_dir: raise AssertionError( 'Manual directory was enabled. ' @@ -621,7 +672,19 @@ def manual_dir(self): def _read_url_info(url_path: str) -> checksums.UrlInfo: - """Loads the `UrlInfo` from the `.INFO` file.""" + """Loads the `UrlInfo` from the `.INFO` file. + + Args: + url_path: The path to the .INFO file. + + Returns: + UrlInfo object. + + Raises: + ValueError: If 'url_info' is not found in the .INFO file. + This likely indicates the files were downloaded + with a previous version of TFDS (<=3.1.0). 
+ """ file_info = resource_lib.read_info_file(url_path) if 'url_info' not in file_info: raise ValueError( @@ -636,7 +699,15 @@ def _wait_on_promise(p): def _map_promise(map_fn, all_inputs): - """Map the function into each element and resolve the promise.""" + """Map the function into each element and resolve the promise. + + Args: + map_fn: The function to be mapped into each element. + all_inputs: The elements the function is to be mapped into. + + Returns: + The resolved promise. + """ all_promises = tf.nest.map_structure(map_fn, all_inputs) # Apply the function res = tf.nest.map_structure(_wait_on_promise, all_promises) return res diff --git a/tensorflow_datasets/core/download/download_manager_test.py b/tensorflow_datasets/core/download/download_manager_test.py index fb31087700e..c728655f665 100644 --- a/tensorflow_datasets/core/download/download_manager_test.py +++ b/tensorflow_datasets/core/download/download_manager_test.py @@ -41,10 +41,19 @@ def _sha256(str_): + """Returns the SHA256 hexdump of the given string. + + Args: + str_: String to compute the SHA256 hexdump of. + + Returns: + SHA256 hexdump of the given string. + """ return hashlib.sha256(str_.encode('utf8')).hexdigest() class Artifact(object): + """Artifact class for tracking files used for testing.""" # For testing only. def __init__(self, name, url=None): @@ -86,6 +95,15 @@ def _make_downloader_mock(self): """`downloader.download` patch which creates the returns the path.""" def _download(url, tmpdir_path): + """Download function of the DownloadManager. + + Args: + url: URL to download from. + tmpdir_path: Path to the temporary directory. + + Returns: + The resolved promise. + """ self.downloaded_urls.append(url) # Record downloader.download() calls # If the name isn't explicitly provided, then it is extracted from the # url. @@ -135,6 +153,12 @@ def setUp(self): self.addCleanup(absltest.mock.patch.stopall) def _write_info(self, path, info): + """Writes the content to the .INFO file. + + Args: + path: Path to the .INFO file. + info: Content to be written. + """ content = json.dumps(info) self.fs.add_file(path, content) @@ -146,6 +170,17 @@ def _get_manager( extract_dir='/extract_dir', **kwargs ): + """Returns the DownloadManager object. + + Args: + register_checksums: Whether or not to register the checksums. + url_infos: UrlInfos for the URLs. + dl_dir: Path to the download directory. + extract_dir: Path to the extraction directory. + + Returns: + DownloadManager object. + """ manager = dm.DownloadManager( dataset_name='mnist', download_dir=dl_dir, @@ -257,7 +292,6 @@ def test_download_and_extract_archive_ext_in_fname(self): 'a': '/extract_dir/ZIP.%s' % a.file_name, }) - def test_download_and_extract_already_downloaded(self): a = Artifact('a') # Extract can't be deduced from the url, but from .INFO # File was already downloaded: diff --git a/tensorflow_datasets/core/download/downloader.py b/tensorflow_datasets/core/download/downloader.py index f353ee0e7f7..ceafe18a2ba 100644 --- a/tensorflow_datasets/core/download/downloader.py +++ b/tensorflow_datasets/core/download/downloader.py @@ -46,10 +46,19 @@ @utils.memoize() def get_downloader(*args: Any, **kwargs: Any) -> '_Downloader': + """Returns the _Downloader object.""" return _Downloader(*args, **kwargs) def _get_filename(response: Response) -> str: + """Returns the file name given the response: + + Args: + response: HTTP request response. + + Returns: + File name. 
+ """ content_disposition = response.headers.get('content-disposition', None) if content_disposition: match = re.findall('filename="(.+?)"', content_disposition) @@ -59,6 +68,7 @@ def _get_filename(response: Response) -> str: class DownloadError(Exception): + """Exception class for download errors.""" pass @@ -109,6 +119,15 @@ def download(self, url: str, destination_path: str): def _sync_file_copy( self, filepath: str, destination_path: str) -> checksums_lib.UrlInfo: + """Copy files from source to destination. + + Args: + filepath: Source path. + destination_path: Destination path. + + Returns: + Url checksum. + """ out_path = os.path.join(destination_path, os.path.basename(filepath)) tf.io.gfile.copy(filepath, out_path) hexdigest, size = utils.read_checksum_digest( @@ -176,7 +195,7 @@ def _open_url(url: str) -> ContextManager[Tuple[Response, Iterable[bytes]]]: Returns: response: The url response with `.url` and `.header` attributes. - iter_content: A `bytes` iterator which yield the content. + iter_content: A `bytes` iterator which yields the content. """ # Download FTP urls with `urllib`, otherwise use `requests` open_fn = _open_with_urllib if url.startswith('ftp') else _open_with_requests @@ -185,6 +204,14 @@ def _open_url(url: str) -> ContextManager[Tuple[Response, Iterable[bytes]]]: @contextlib.contextmanager def _open_with_requests(url: str) -> Iterator[Tuple[Response, Iterable[bytes]]]: + """Open url using the requests package. + + Args: + url: Url to open. + + Returns: + Iterator[Tuple[Response, Iterable[bytes]]] + """ with requests.Session() as session: if _DRIVE_URL.match(url): url = _get_drive_url(url, session) @@ -195,6 +222,14 @@ def _open_with_requests(url: str) -> Iterator[Tuple[Response, Iterable[bytes]]]: @contextlib.contextmanager def _open_with_urllib(url: str) -> Iterator[Tuple[Response, Iterable[bytes]]]: + """Open url using the urllib package. + + Args: + url: Url to open. + + Returns: + Iterator[Tuple[Response, Iterable[bytes]]] + """ with urllib.request.urlopen(url) as response: # pytype: disable=attribute-error yield ( response, @@ -203,7 +238,15 @@ def _open_with_urllib(url: str) -> Iterator[Tuple[Response, Iterable[bytes]]]: def _get_drive_url(url: str, session: requests.Session) -> str: - """Returns url, possibly with confirmation token.""" + """Returns the drive url, possibly with confirmation token. + + Args: + url: Drive url. + session: Requests Session object. + + Returns: + Drive url. + """ with session.get(url, stream=True) as response: _assert_status(response) for k, v in response.cookies.items(): @@ -214,7 +257,14 @@ def _get_drive_url(url: str, session: requests.Session) -> str: def _assert_status(response: requests.Response) -> None: - """Ensure the URL response is 200.""" + """Ensure the URL response is 200. + + Args: + response: Requests Response object. + + Raises: + DownloadError: If the response code is not 200. + """ if response.status_code != 200: raise DownloadError('Failed to get url {}. HTTP code: {}.'.format( response.url, response.status_code)) diff --git a/tensorflow_datasets/core/download/downloader_test.py b/tensorflow_datasets/core/download/downloader_test.py index c191db8823c..385b0dff8c3 100644 --- a/tensorflow_datasets/core/download/downloader_test.py +++ b/tensorflow_datasets/core/download/downloader_test.py @@ -33,7 +33,15 @@ class _FakeResponse(object): - + """URL response used for testing. + + Attributes: + url: URL response URL. + content: URL response content. + cookies: URL response cookies. + headers: URL response header. 
+    status_code: URL response status code.
+  """
   def __init__(self, url, content, cookies=None, headers=None, status_code=200):
     self.url = url
     self.raw = io.BytesIO(content)
@@ -50,12 +58,14 @@ def __exit__(self, *args):
     return
 
   def iter_content(self, chunk_size):
+    """Iterates over the content of the URL response."""
     del chunk_size
     for line in self.raw:
       yield line
 
 
 class DownloaderTest(testing.TestCase):
+  """Tests for downloader.py."""
 
   def setUp(self):
     super(DownloaderTest, self).setUp()
@@ -83,6 +93,7 @@ def setUp(self):
     ).start()
 
   def test_ok(self):
+    """Test download from URL."""
     promise = self.downloader.download(self.url, self.tmp_dir)
     url_info = promise.get()
     self.assertEqual(url_info.checksum, self.resp_checksum)
@@ -91,6 +102,7 @@
     self.assertFalse(tf.io.gfile.exists(self.incomplete_path))
 
   def test_drive_no_cookies(self):
+    """Test download from Google Drive without cookies."""
     url = 'https://drive.google.com/uc?export=download&id=a1b2bc3'
     promise = self.downloader.download(url, self.tmp_dir)
     url_info = promise.get()
@@ -100,10 +112,12 @@
     self.assertFalse(tf.io.gfile.exists(self.incomplete_path))
 
   def test_drive(self):
+    """Test download from Google Drive with cookies."""
     self.cookies = {'foo': 'bar', 'download_warning_a': 'token', 'a': 'b'}
     self.test_drive_no_cookies()
 
   def test_http_error(self):
+    """Test HTTP file serving error."""
     error = downloader.requests.exceptions.HTTPError('Problem serving file.')
     absltest.mock.patch.object(
         downloader.requests.Session, 'get', side_effect=error).start()
@@ -112,6 +126,7 @@
       promise.get()
 
   def test_bad_http_status(self):
+    """Test 404 HTTP status."""
     absltest.mock.patch.object(
         downloader.requests.Session,
         'get',
@@ -122,6 +137,7 @@
       promise.get()
 
   def test_ftp(self):
+    """Test download over FTP."""
     url = 'ftp://username:password@example.com/foo.tar.gz'
     promise = self.downloader.download(url, self.tmp_dir)
     url_info = promise.get()
@@ -131,6 +147,7 @@
     self.assertFalse(tf.io.gfile.exists(self.incomplete_path))
 
   def test_ftp_error(self):
+    """Test download error over FTP."""
     error = downloader.urllib.error.URLError('Problem serving file.')
     absltest.mock.patch.object(
         downloader.urllib.request,
@@ -144,13 +161,16 @@
 
 
 class GetFilenameTest(testing.TestCase):
+  """Tests to obtain file names."""
 
   def test_no_headers(self):
+    """Test file name obtained from URL response."""
     resp = _FakeResponse('http://foo.bar/baz.zip', b'content')
     res = downloader._get_filename(resp)
     self.assertEqual(res, 'baz.zip')
 
   def test_headers(self):
+    """Test file name obtained from URL response using headers."""
     cdisp = ('attachment;filename="hello.zip";'
              'filename*=UTF-8\'\'hello.zip')
     resp = _FakeResponse('http://foo.bar/baz.zip', b'content', headers={
diff --git a/tensorflow_datasets/core/download/extractor.py b/tensorflow_datasets/core/download/extractor.py
index abf416f59e8..616a23bcf11 100644
--- a/tensorflow_datasets/core/download/extractor.py
+++ b/tensorflow_datasets/core/download/extractor.py
@@ -47,6 +47,7 @@
 
 @utils.memoize()
 def get_extractor(*args, **kwargs):
+  """Returns an _Extractor object."""
   return _Extractor(*args, **kwargs)
 
 
@@ -59,7 +60,11 @@ class UnsafeArchiveError(Exception):
 
 
 class _Extractor(object):
-  """Singleton (use `get_extractor()` module fct) to extract archives."""
+  """Singleton (use `get_extractor()` module fct) to extract archives.
+
+  Attributes:
+    max_workers: Max number of worker threads extracting archives asynchronously.
+  """
 
   def __init__(self, max_workers=12):
     self._executor = concurrent.futures.ThreadPoolExecutor(
@@ -75,7 +80,16 @@ def tqdm(self):
       yield
 
   def extract(self, path, extract_method, to_path):
-    """Returns `promise.Promise` => to_path."""
+    """Returns `promise.Promise` => to_path.
+
+    Args:
+      path: Path to the file to extract from.
+      extract_method: Extraction method.
+      to_path: Path to where the file is to be extracted.
+
+    Returns:
+      The resolved promise.
+    """
     self._pbar_path.update_total(1)
     if extract_method not in _EXTRACT_METHODS:
       raise ValueError('Unknown extraction method "%s".' % extract_method)
@@ -84,7 +98,19 @@ def extract(self, path, extract_method, to_path):
     return promise.Promise.resolve(future)
 
   def _sync_extract(self, from_path, method, to_path):
-    """Returns `to_path` once resource has been extracted there."""
+    """Returns `to_path` once resource has been extracted there.
+
+    Args:
+      from_path: The path to the file to be extracted.
+      method: Extraction method.
+      to_path: Path to where the file is to be extracted.
+
+    Returns:
+      Path to where the file was extracted.
+
+    Raises:
+      ExtractError: If the extraction fails, e.g. path length > 260 on Windows.
+    """
     to_path_tmp = '%s%s_%s' % (to_path, constants.INCOMPLETE_SUFFIX,
                                uuid.uuid4().hex)
     path = None
@@ -102,7 +128,7 @@ def _sync_extract(self, from_path, method, to_path):
       msg += (
          '\n'
          'On windows, path lengths greater than 260 characters may '
-          'result in an error. See the doc to remove the limiration: '
+          'result in an error. See the doc to remove the limitation: '
          'https://docs.python.org/3/using/windows.html#removing-the-max-path-limitation'
      )
      raise ExtractError(msg)
@@ -116,7 +142,12 @@
 
 
 def _copy(src_file, dest_path):
-  """Copy data read from src file obj to new file in dest_path."""
+  """Copy data read from src file obj to new file in dest_path.
+
+  Args:
+    src_file: Source file object.
+    dest_path: Path to copy the source file to.
+  """
   tf.io.gfile.makedirs(os.path.dirname(dest_path))
   with tf.io.gfile.GFile(dest_path, 'wb') as dest_file:
     while True:
@@ -127,6 +158,7 @@
 
 
 def _normpath(path):
+  """Returns the normalized path name."""
   path = os.path.normpath(path)
   if (path.startswith('.')
       or os.path.isabs(path)
@@ -138,6 +170,14 @@
 
 @contextlib.contextmanager
 def _open_or_pass(path_or_fobj):
+  """Yields the file object given the path or the file object.
+
+  Args:
+    path_or_fobj: Path to the file or the file object.
+
+  Yields:
+    File object.
+  """
   if isinstance(path_or_fobj, six.string_types):
     with tf.io.gfile.GFile(path_or_fobj, 'rb') as f_obj:
       yield f_obj
@@ -175,23 +215,54 @@ def iter_tar(arch_f, stream=False):
 
 
 def iter_tar_stream(arch_f):
+  """Iterates over the tar file object stream.
+
+  Args:
+    arch_f: File object of the archive to iterate.
+
+  Returns:
+    An iterator over (filepath, extracted_fobj) for each file in the archive.
+  """
   return iter_tar(arch_f, stream=True)
 
 
 def iter_gzip(arch_f):
+  """Iterates over the gzip-compressed file.
+
+  Args:
+    arch_f: File object of the archive to iterate.
+
+  Yields:
+    A ('', GzipFile) tuple.
+  """
   with _open_or_pass(arch_f) as fobj:
     gzip_ = gzip.GzipFile(fileobj=fobj)
     yield ('', gzip_)  # No inner file.
 
 
 def iter_bzip2(arch_f):
+  """Iterates over the bzip2-compressed file.
+
+  Args:
+    arch_f: File object of the archive to iterate.
+
+  Yields:
+    A ('', BZ2File) tuple.
+ """ with _open_or_pass(arch_f) as fobj: bz2_ = bz2.BZ2File(filename=fobj) yield ('', bz2_) # No inner file. def iter_zip(arch_f): - """Iterate over zip archive.""" + """Iterate over zip archive. + + Args: + arch_f: File object of the archive to iterate. + + Yields: + (filepath, extracted_fobj) for each file in the archive. + """ with _open_or_pass(arch_f) as fobj: z = zipfile.ZipFile(fobj) for member in z.infolist(): diff --git a/tensorflow_datasets/core/download/extractor_test.py b/tensorflow_datasets/core/download/extractor_test.py index 16343cbcd70..f77c41b1870 100644 --- a/tensorflow_datasets/core/download/extractor_test.py +++ b/tensorflow_datasets/core/download/extractor_test.py @@ -39,14 +39,17 @@ def _read(path): + """Read from the file.""" with tf.io.gfile.GFile(path, 'rb') as f: return f.read() class ExtractorTest(testing.TestCase): + """Tests for extractor.py""" @classmethod def setUpClass(cls): + """Set up the class.""" super(ExtractorTest, cls).setUpClass() f1_path = os.path.join(cls.test_data, '6pixels.png') f2_path = os.path.join(cls.test_data, 'foo.csv') @@ -67,10 +70,18 @@ def setUp(self): self.result_path = os.path.join(self.to_path, '6pixels.png') def test_unknown_method(self): + """Test unknown extraction to raise ValueError.""" with self.assertRaises(ValueError): self.extractor.extract('from/path', NO_EXTRACT, 'to/path') def _test_extract(self, method, archive_name, expected_files): + """Test extraction. + + Args: + method: Extraction method. + archive_name: Name of the archived file to extract. + expected_files: Files expected to be extracted. + """ from_path = os.path.join(self.test_data, 'archives', archive_name) self.extractor.extract(from_path, method, self.to_path).get() for name, content in expected_files.items(): @@ -78,37 +89,44 @@ def _test_extract(self, method, archive_name, expected_files): self.assertEqual(_read(path), content, 'File %s has bad content.' % path) def test_zip(self): + """Test extracting a .zip file.""" self._test_extract( ZIP, 'arch1.zip', {'6pixels.png': self.f1_content, 'foo.csv': self.f2_content}) def test_tar(self): + """Test extracting a .tar file.""" self._test_extract( TAR, 'arch1.tar', {'6pixels.png': self.f1_content, 'foo.csv': self.f2_content}) def test_targz(self): + """Test extracting a .tar.gz file.""" self._test_extract( TAR_GZ, 'arch1.tar.gz', {'6pixels.png': self.f1_content, 'foo.csv': self.f2_content}) def test_tar_stream(self): + """Test extracting a .tar file using the stream.""" self._test_extract( TAR_STREAM, 'arch1.tar', {'6pixels.png': self.f1_content, 'foo.csv': self.f2_content}) def test_targz_stream(self): + """Test extracting a .zip file using the stream.""" self._test_extract( TAR_GZ_STREAM, 'arch1.tar.gz', {'6pixels.png': self.f1_content, 'foo.csv': self.f2_content}) def test_gzip(self): + """Test extracting a .tar.gz archive.""" from_path = os.path.join(self.test_data, 'archives', 'arch1.tar.gz') self.extractor.extract(from_path, GZIP, self.to_path).get() arch1_path = os.path.join(self.test_data, 'archives', 'arch1.tar') self.assertEqual(_read(self.to_path), _read(arch1_path)) def test_gzip2(self): + """Test extracting a .gz archive.""" # Same as previous test, except it is not a .tar.gz, but a .gz. 
     from_path = os.path.join(self.test_data, 'archives', 'foo.csv.gz')
     self.extractor.extract(from_path, GZIP, self.to_path).get()
@@ -116,16 +134,19 @@ def test_gzip2(self):
     self.assertEqual(_read(self.to_path), _read(foo_csv_path))
 
   def test_bzip2(self):
+    """Test extracting a .bz2 archive."""
     from_path = os.path.join(self.test_data, 'archives', 'foo.csv.bz2')
     self.extractor.extract(from_path, BZIP2, self.to_path).get()
     foo_csv_path = os.path.join(self.test_data, 'foo.csv')
     self.assertEqual(_read(self.to_path), _read(foo_csv_path))
 
   def test_absolute_path(self):
+    """Test that files with absolute paths in the archive are ignored."""
     # There is a file with absolute path (ignored) + a file named "foo".
     self._test_extract(TAR, 'absolute_path.tar', {'foo': b'bar\n'})
 
   def test_wrong_method(self):
+    """Test extracting a file using the wrong extraction method."""
     from_path = os.path.join(self.test_data, 'archives', 'foo.csv.gz')
     promise = self.extractor.extract(from_path, ZIP, self.to_path)
     expected_msg = 'File is not a zip file'
diff --git a/tensorflow_datasets/core/download/kaggle_test.py b/tensorflow_datasets/core/download/kaggle_test.py
index a84d3330fb3..a70ac34c8e2 100644
--- a/tensorflow_datasets/core/download/kaggle_test.py
+++ b/tensorflow_datasets/core/download/kaggle_test.py
@@ -28,8 +28,10 @@
 
 
 class KaggleTest(testing.TestCase):
+  """Tests for kaggle.py."""
 
   def test_competition_download(self):
+    """Test downloading a kaggle competition."""
     with testing.mock_kaggle_api():
       with testing.tmp_dir() as tmp_dir:
         out_path = kaggle.download_kaggle_data('digit-recognizer', tmp_dir)
@@ -38,6 +40,7 @@ def test_competition_download(self):
         self.assertEqual('digit-recognizer', f.read())
 
   def test_dataset_download(self):
+    """Test downloading a kaggle dataset."""
     with testing.mock_kaggle_api():
       with testing.tmp_dir() as tmp_dir:
         out_path = kaggle.download_kaggle_data('user/dataset', tmp_dir)
@@ -46,6 +49,7 @@ def test_dataset_download(self):
         self.assertEqual('user/dataset', f.read())
 
   def test_competition_download_404(self):
+    """Test 404 - Not found error using a non-existent kaggle competition."""
     with testing.mock_kaggle_api(err_msg='404 - Not found'):
       with testing.tmp_dir() as tmp_dir:
         with self.assertRaisesRegex(
@@ -53,6 +57,7 @@ def test_competition_download_404(self):
           kaggle.download_kaggle_data('digit-recognize', tmp_dir)
 
   def test_kaggle_type(self):
+    """Test whether the determined kaggle types are correct."""
     self.assertEqual(
         kaggle._get_kaggle_type('digit-recognizer'),
         'competitions'
diff --git a/tensorflow_datasets/core/download/resource.py b/tensorflow_datasets/core/download/resource.py
index dba6c1fa8ec..d1ec59b6fd8 100644
--- a/tensorflow_datasets/core/download/resource.py
+++ b/tensorflow_datasets/core/download/resource.py
@@ -45,7 +45,14 @@
 
 
 def _decode_hex(hexstr):
-  """Returns binary digest, given str hex digest."""
+  """Returns binary digest, given str hex digest.
+
+  Args:
+    hexstr: Hex digest string to decode.
+
+  Returns:
+    Decoded binary digest.
+  """
   return _hex_codec(hexstr)[0]
 
 
@@ -96,7 +103,14 @@ class ExtractMethod(enum.Enum):
 
 
 def _guess_extract_method(fname):
-  """Guess extraction method, given file name (or path)."""
+  """Guess extraction method, given file name (or path).
+
+  Args:
+    fname: File name or path to the file.
+
+  Returns:
+    Method to be used for extraction.
+ """ for method, extensions in _EXTRACTION_METHOD_TO_EXTS: for ext in extensions: if fname.endswith(ext): @@ -196,18 +210,39 @@ def get_dl_fname(url, checksum): def get_dl_dirname(url): - """Returns name of temp dir for given url.""" + """Returns name of temp dir for given url. + + Args: + url: URL to find the temp dir of. + + Returns: + Name of the temp dir. + """ checksum = hashlib.sha256(tf.compat.as_bytes(url)).hexdigest() return get_dl_fname(url, checksum) def _get_info_path(path): - """Returns path (`str`) of INFO file associated with resource at path.""" + """Returns path (`str`) of INFO file associated with resource at path. + + Args: + path: Path to the resource. + + Returns: + Path to the .INFO file. + """ return '%s.INFO' % path def _read_info(info_path) -> Json: - """Returns info dict or None.""" + """Returns info dict or None. + + Args: + info_path: Path to the .INFO file. + + Returns: + .INFO file content. + """ if not tf.io.gfile.exists(info_path): return None with tf.io.gfile.GFile(info_path) as info_f: @@ -223,12 +258,27 @@ def rename_info_file( dst_path: str, overwrite: bool = False, ) -> None: + """Renames the .INFO file. + + Args: + src_path: Path to the source resource. + dst_path: Path to the destination resource. + overwrite: Whether the dst .INFO file is to be occupied by an existing file. + """ tf.io.gfile.rename( _get_info_path(src_path), _get_info_path(dst_path), overwrite=overwrite) @synchronize_decorator def read_info_file(info_path: str) -> Json: + """Reads the .INFO file. + + Args: + info_path: Path to the .INFO file. + + Returns: + Content of the .INFO file. + """ return _read_info(_get_info_path(info_path)) @@ -280,7 +330,14 @@ def write_info_file( def get_extract_method(path): - """Returns `ExtractMethod` to use on resource at path. Cannot be None.""" + """Returns `ExtractMethod` to use on resource at path. Cannot be None. + + Args: + path: Path to the resource. + + Returns: + The extraction method to be used. + """ info_path = _get_info_path(path) info = _read_info(info_path) fname = info.get('original_fname', path) if info else path @@ -310,7 +367,14 @@ def __init__(self, @classmethod def exists_locally(cls, path): - """Returns whether the resource exists locally, at `resource.path`.""" + """Returns whether the resource exists locally, at `resource.path`. + + Args: + path: Path to the resource. + + Returns: + Whether the resource exists locally. + """ # If INFO file doesn't exist, consider resource does NOT exist, as it would # prevent guessing the `extract_method`. return (tf.io.gfile.exists(path) and @@ -318,7 +382,11 @@ def exists_locally(cls, path): @property def extract_method(self): - """Returns `ExtractMethod` to use on resource. Cannot be None.""" + """Returns `ExtractMethod` to use on resource. Cannot be None. + + Returns: + The extraction method to be used. 
+ """ if self._extract_method: return self._extract_method return get_extract_method(self.path) diff --git a/tensorflow_datasets/core/download/resource_test.py b/tensorflow_datasets/core/download/resource_test.py index 84d25ec9ed6..80dbb58eafa 100644 --- a/tensorflow_datasets/core/download/resource_test.py +++ b/tensorflow_datasets/core/download/resource_test.py @@ -31,8 +31,10 @@ class GuessExtractMethodTest(testing.TestCase): + """Tests to guess the extraction method.""" def test_(self): + """Test if the extraction method correspond to the respective file type.""" for fname, expected_result in [ ('bar.tar.gz', TAR_GZ), ('bar.gz', GZIP), @@ -48,6 +50,8 @@ def test_(self): class DlDirNameTest(testing.TestCase): + """Tests for the download directory.""" + urls = '''\ http://data.statmt.org/wmt17/translation-task/dev.tgz http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz @@ -82,6 +86,7 @@ class DlDirNameTest(testing.TestCase): '''.split('\n') def test_(self): + """Test if the files downloaded correspond to the respective URLs.""" for url, expected in zip(self.urls, self.expected): res = resource.get_dl_dirname(url) self.assertEqual(res, expected)