Added pathlib support to datasets/utils.py #8200

Merged (15 commits, Jan 16, 2024)
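In short, the utility functions in torchvision.datasets.utils now accept pathlib.Path wherever they previously required str, and extract_archive mirrors the input type in its return value. A minimal sketch of the new behavior, assuming a torchvision build with this PR merged and a hypothetical local archive:

import pathlib

from torchvision.datasets import utils

archive = pathlib.Path("data") / "archive.zip"  # hypothetical archive on disk
extracted_dir = utils.extract_archive(archive, archive.parent)
assert isinstance(extracted_dir, pathlib.Path)  # Path in -> Path out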
35 changes: 29 additions & 6 deletions test/test_datasets_utils.py
@@ -58,8 +58,11 @@ def test_get_redirect_url_max_hops_exceeded(self, mocker):
        assert mock.call_count == 1
        assert mock.call_args[0][0].full_url == url

-    def test_check_md5(self):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_check_md5(self, use_pathlib):
        fpath = TEST_FILE
+        if use_pathlib:
+            fpath = pathlib.Path(fpath)
        correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
        false_md5 = ""
        assert utils.check_md5(fpath, correct_md5)
@@ -116,7 +119,8 @@ def test_detect_file_type_incompatible(self, file):
            utils._detect_file_type(file)

    @pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"])
-    def test_decompress(self, extension, tmpdir):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_decompress(self, extension, tmpdir, use_pathlib):
        def create_compressed(root, content="this is the content"):
            file = os.path.join(root, "file")
            compressed = f"{file}{extension}"
@@ -128,6 +132,8 @@ def create_compressed(root, content="this is the content"):
            return compressed, file, content

        compressed, file, content = create_compressed(tmpdir)
+        if use_pathlib:
+            compressed = pathlib.Path(compressed)

        utils._decompress(compressed)

@@ -140,7 +146,8 @@ def test_decompress_no_compression(self):
        with pytest.raises(RuntimeError):
            utils._decompress("foo.tar")

-    def test_decompress_remove_finished(self, tmpdir):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_decompress_remove_finished(self, tmpdir, use_pathlib):
        def create_compressed(root, content="this is the content"):
            file = os.path.join(root, "file")
            compressed = f"{file}.gz"
@@ -151,10 +158,20 @@ def create_compressed(root, content="this is the content"):
            return compressed, file, content

        compressed, file, content = create_compressed(tmpdir)
+        print(f"{type(compressed)=}")
+        if use_pathlib:
+            compressed = pathlib.Path(compressed)
+            tmpdir = pathlib.Path(tmpdir)

-        utils.extract_archive(compressed, tmpdir, remove_finished=True)
+        extracted_dir = utils.extract_archive(compressed, tmpdir, remove_finished=True)

        assert not os.path.exists(compressed)
+        if use_pathlib:
+            assert isinstance(extracted_dir, pathlib.Path)
+            assert isinstance(compressed, pathlib.Path)
+        else:
+            assert isinstance(extracted_dir, str)
+            assert isinstance(compressed, str)

    @pytest.mark.parametrize("extension", [".gz", ".xz"])
    @pytest.mark.parametrize("remove_finished", [True, False])
@@ -167,7 +184,8 @@ def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker):

        mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)

-    def test_extract_zip(self, tmpdir):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_extract_zip(self, tmpdir, use_pathlib):
        def create_archive(root, content="this is the content"):
            file = os.path.join(root, "dst.txt")
            archive = os.path.join(root, "archive.zip")
@@ -177,6 +195,8 @@ def create_archive(root, content="this is the content"):

            return archive, file, content

+        if use_pathlib:
+            tmpdir = pathlib.Path(tmpdir)
        archive, file, content = create_archive(tmpdir)

        utils.extract_archive(archive, tmpdir)
@@ -189,7 +209,8 @@ def create_archive(root, content="this is the content"):
    @pytest.mark.parametrize(
        "extension, mode", [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz")]
    )
-    def test_extract_tar(self, extension, mode, tmpdir):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_extract_tar(self, extension, mode, tmpdir, use_pathlib):
        def create_archive(root, extension, mode, content="this is the content"):
            src = os.path.join(root, "src.txt")
            dst = os.path.join(root, "dst.txt")
@@ -203,6 +224,8 @@ def create_archive(root, extension, mode, content="this is the content"):

            return archive, dst, content

+        if use_pathlib:
+            tmpdir = pathlib.Path(tmpdir)
        archive, file, content = create_archive(tmpdir, extension, mode)

        utils.extract_archive(archive, tmpdir)
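The new tests all follow the same pattern: build the fixture path as a str, optionally wrap it in pathlib.Path, and assert the API behaves identically. Reduced to a self-contained sketch (the test name and zip layout here are illustrative, not from the diff):

import os
import pathlib
import zipfile

import pytest

from torchvision.datasets import utils


@pytest.mark.parametrize("use_pathlib", (True, False))
def test_extract_roundtrip(tmpdir, use_pathlib):
    # Create a one-file zip archive inside the pytest-provided tmpdir.
    archive = os.path.join(tmpdir, "archive.zip")
    with zipfile.ZipFile(archive, "w") as zf:
        zf.writestr("dst.txt", "this is the content")

    if use_pathlib:
        archive = pathlib.Path(archive)
        tmpdir = pathlib.Path(tmpdir)

    # After this PR the return type mirrors the type of `from_path`.
    extracted_dir = utils.extract_archive(archive, tmpdir)
    assert isinstance(extracted_dir, pathlib.Path if use_pathlib else str)
    assert (pathlib.Path(tmpdir) / "dst.txt").read_text() == "this is the content"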
69 changes: 46 additions & 23 deletions torchvision/datasets/utils.py
@@ -30,7 +30,7 @@

def _save_response_content(
    content: Iterator[bytes],
-    destination: str,
+    destination: Union[str, pathlib.Path],
    length: Optional[int] = None,
) -> None:
    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
@@ -43,12 +43,12 @@ def _save_response_content(
            pbar.update(len(chunk))


-def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
+def _urlretrieve(url: str, filename: Union[str, pathlib.Path], chunk_size: int = 1024 * 32) -> None:
    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
        _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)


-def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
+def calculate_md5(fpath: Union[str, pathlib.Path], chunk_size: int = 1024 * 1024) -> str:
    # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
    # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
    # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere.
@@ -62,11 +62,11 @@ def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
    return md5.hexdigest()


-def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
+def check_md5(fpath: Union[str, pathlib.Path], md5: str, **kwargs: Any) -> bool:
    return md5 == calculate_md5(fpath, **kwargs)


-def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
+def check_integrity(fpath: Union[str, pathlib.Path], md5: Optional[str] = None) -> bool:
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
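With the signatures above widened, the integrity checks take either path type. A small sketch with a hypothetical file (the checksum is computed on the spot rather than hard-coded):

import hashlib
import pathlib
import tempfile

from torchvision.datasets import utils

fpath = pathlib.Path(tempfile.mkdtemp()) / "file.bin"  # hypothetical file
fpath.write_bytes(b"this is the content")

md5 = hashlib.md5(fpath.read_bytes()).hexdigest()
assert utils.check_md5(fpath, md5)             # Path accepted, same as str
assert utils.check_integrity(fpath, md5)
assert not utils.check_integrity(fpath, "0" * 32)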
@@ -106,7 +106,7 @@ def _get_google_drive_file_id(url: str) -> Optional[str]:
def download_url(
    url: str,
    root: Union[str, pathlib.Path],
-    filename: Optional[str] = None,
+    filename: Optional[Union[str, pathlib.Path]] = None,
    md5: Optional[str] = None,
    max_redirect_hops: int = 3,
) -> None:
@@ -159,7 +159,7 @@ def download_url(
        raise RuntimeError("File not found or corrupted.")


-def list_dir(root: str, prefix: bool = False) -> List[str]:
+def list_dir(root: Union[str, pathlib.Path], prefix: bool = False) -> List[str]:
    """List all directories at a given root

    Args:
@@ -174,7 +174,7 @@ def download_url(
    return directories


-def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
+def list_files(root: Union[str, pathlib.Path], suffix: str, prefix: bool = False) -> List[str]:
    """List all files ending with a suffix at a given root

    Args:
@@ -208,7 +208,10 @@ def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple


def download_file_from_google_drive(
-    file_id: str, root: Union[str, pathlib.Path], filename: Optional[str] = None, md5: Optional[str] = None
+    file_id: str,
+    root: Union[str, pathlib.Path],
+    filename: Optional[Union[str, pathlib.Path]] = None,
+    md5: Optional[str] = None,
):
    """Download a Google Drive file from and place it in root.

@@ -278,7 +281,9 @@ def download_file_from_google_drive(
    )


-def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
+def _extract_tar(
+    from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
+) -> None:
    with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
        tar.extractall(to_path)

@@ -289,14 +294,16 @@
}


-def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
+def _extract_zip(
+    from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
+) -> None:
    with zipfile.ZipFile(
        from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
    ) as zip:
        zip.extractall(to_path)


-_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
+_ARCHIVE_EXTRACTORS: Dict[str, Callable[[Union[str, pathlib.Path], Union[str, pathlib.Path], Optional[str]], None]] = {
    ".tar": _extract_tar,
    ".zip": _extract_zip,
}
@@ -312,7 +319,7 @@ def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
}


-def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
+def _detect_file_type(file: Union[str, pathlib.Path]) -> Tuple[str, Optional[str], Optional[str]]:
    """Detect the archive type and/or compression of a file.

    Args:
@@ -355,7 +362,11 @@ def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
    raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")


-def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
+def _decompress(
+    from_path: Union[str, pathlib.Path],
+    to_path: Optional[Union[str, pathlib.Path]] = None,
+    remove_finished: bool = False,
+) -> pathlib.Path:
    r"""Decompress a file.

    The compression is automatically detected from the file name.
@@ -373,7 +384,7 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
        raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")

    if to_path is None:
-        to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")
+        to_path = pathlib.Path(os.fspath(from_path).replace(suffix, archive_type if archive_type is not None else ""))

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]
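The os.fspath call above is what lets the suffix replacement work for both input types; the derived default keeps the archive suffix and drops only the compression. A sketch of the effect on a hypothetical file (note that _decompress is a private helper):

import gzip
import pathlib
import tempfile

from torchvision.datasets import utils

root = pathlib.Path(tempfile.mkdtemp())
with gzip.open(root / "file.tar.gz", "wb") as f:  # content need not be a real tarball
    f.write(b"just bytes to decompress")

# ".tar.gz" is detected as archive type ".tar" plus compression ".gz", so the
# default to_path only strips ".gz":
out = utils._decompress(root / "file.tar.gz")
assert out == root / "file.tar"  # _decompress now always returns a pathlib.Path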
@@ -384,10 +395,14 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    if remove_finished:
        os.remove(from_path)

-    return to_path
+    return pathlib.Path(to_path)


-def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
+def extract_archive(
+    from_path: Union[str, pathlib.Path],
+    to_path: Optional[Union[str, pathlib.Path]] = None,
+    remove_finished: bool = False,
+) -> Union[str, pathlib.Path]:
    """Extract an archive.

    The archive type and a possible compression is automatically detected from the file name. If the file is compressed
@@ -402,16 +417,24 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    Returns:
        (str): Path to the directory the file was extracted to.
    """
+
+    def path_or_str(ret_path: pathlib.Path) -> Union[str, pathlib.Path]:
+        if isinstance(from_path, str):
+            return os.fspath(ret_path)
+        else:
+            return ret_path
+
    if to_path is None:
        to_path = os.path.dirname(from_path)

    suffix, archive_type, compression = _detect_file_type(from_path)
    if not archive_type:
-        return _decompress(
+        ret_path = _decompress(
            from_path,
            os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
            remove_finished=remove_finished,
        )
+        return path_or_str(ret_path)

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    extractor = _ARCHIVE_EXTRACTORS[archive_type]
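path_or_str is what keeps the public contract backward compatible: str callers still get a str back, while Path callers get a Path. A self-contained sketch (the archive name is hypothetical and created on the spot so the snippet runs):

import pathlib
import zipfile

from torchvision.datasets import utils

zip_path = "archive.zip"  # hypothetical archive
with zipfile.ZipFile(zip_path, "w") as zf:
    zf.writestr("dst.txt", "this is the content")

out_str = utils.extract_archive(zip_path, "out")                 # str in  -> str out
out_path = utils.extract_archive(pathlib.Path(zip_path), "out")  # Path in -> Path out
assert isinstance(out_str, str)
assert isinstance(out_path, pathlib.Path)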
@@ -420,14 +443,14 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    if remove_finished:
        os.remove(from_path)

-    return to_path
+    return path_or_str(pathlib.Path(to_path))


def download_and_extract_archive(
    url: str,
-    download_root: str,
-    extract_root: Optional[str] = None,
-    filename: Optional[str] = None,
+    download_root: Union[str, pathlib.Path],
+    extract_root: Optional[Union[str, pathlib.Path]] = None,
+    filename: Optional[Union[str, pathlib.Path]] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
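download_and_extract_archive composes the pieces above, so with these changes a dataset root can be carried around as a Path end to end. A sketch with a placeholder URL (not a real dataset; the call performs a network download):

import pathlib

from torchvision.datasets.utils import download_and_extract_archive

root = pathlib.Path.home() / "datasets"  # hypothetical download root
download_and_extract_archive(
    "https://example.com/archive.tar.gz",  # placeholder URL
    download_root=root,
    remove_finished=True,
)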
@@ -479,7 +502,7 @@ def verify_str_arg(
    return value


-def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
+def _read_pfm(file_name: Union[str, pathlib.Path], slice_channels: int = 2) -> np.ndarray:
    """Read file in .pfm format. Might contain either 1 or 3 channels of data.

    Args: