diff --git a/torchvision/datasets/imagenette.py b/torchvision/datasets/imagenette.py index 05da537891b..7994199c818 100644 --- a/torchvision/datasets/imagenette.py +++ b/torchvision/datasets/imagenette.py @@ -4,7 +4,7 @@ from PIL import Image from .folder import find_classes, make_dataset -from .utils import download_and_extract_archive, verify_str_arg +from .utils import check_integrity, download_and_extract_archive, verify_str_arg from .vision import VisionDataset @@ -66,8 +66,8 @@ def __init__( if download: self._download() - elif not self._check_exists(): - raise RuntimeError("Dataset not found. You can use download=True to download it.") + elif not self._check_integrity(): + raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it") self.wnids, self.wnid_to_idx = find_classes(self._image_root) self.classes = [self._WNID_TO_CLASS[wnid] for wnid in self.wnids] @@ -76,15 +76,13 @@ def __init__( } self._samples = make_dataset(self._image_root, self.wnid_to_idx, extensions=".jpeg") - def _check_exists(self) -> bool: - return self._size_root.exists() + def _check_integrity(self) -> bool: + return check_integrity(self._size_root.with_suffix(".tgz"), self._md5) def _download(self): - if self._check_exists(): - raise RuntimeError( - f"The directory {self._size_root} already exists. " - f"If you want to re-download or re-extract the images, delete the directory." - ) + if self._check_integrity(): + print("Files already downloaded and verified") + return download_and_extract_archive(self._url, self.root, md5=self._md5)