Added pathlib support to datasets/utils.py #8215

Closed · wants to merge 9 commits
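Not part of the diff: a minimal sketch of the intent of this PR, assuming the torchvision.datasets.utils changes below are applied. The file path and checksum are placeholders, for illustration only.

import pathlib
from torchvision.datasets import utils

# Placeholder path; check_integrity simply returns False if the file does not exist.
fpath = pathlib.Path("data/archive.tar.gz")

# With this change the checksum helpers accept pathlib.Path as well as str:
utils.check_integrity(fpath, md5="d41d8cd98f00b204e9800998ecf8427e")  # placeholder checksum
utils.check_integrity("data/archive.tar.gz", md5=None)                # plain str still works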
35 changes: 29 additions & 6 deletions test/test_datasets_utils.py
@@ -58,8 +58,11 @@ def test_get_redirect_url_max_hops_exceeded(self, mocker):
assert mock.call_count == 1
assert mock.call_args[0][0].full_url == url

def test_check_md5(self):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_check_md5(self, use_pathlib):
fpath = TEST_FILE
if use_pathlib:
fpath = pathlib.Path(fpath)
correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
false_md5 = ""
assert utils.check_md5(fpath, correct_md5)
@@ -116,7 +119,8 @@ def test_detect_file_type_incompatible(self, file):
utils._detect_file_type(file)

@pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"])
def test_decompress(self, extension, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_decompress(self, extension, tmpdir, use_pathlib):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}{extension}"
@@ -128,6 +132,8 @@ def create_compressed(root, content="this is the content"):
return compressed, file, content

compressed, file, content = create_compressed(tmpdir)
if use_pathlib:
compressed = pathlib.Path(compressed)

utils._decompress(compressed)

@@ -140,7 +146,8 @@ def test_decompress_no_compression(self):
with pytest.raises(RuntimeError):
utils._decompress("foo.tar")

def test_decompress_remove_finished(self, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_decompress_remove_finished(self, tmpdir, use_pathlib):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}.gz"
@@ -151,10 +158,20 @@ def create_compressed(root, content="this is the content"):
return compressed, file, content

compressed, file, content = create_compressed(tmpdir)
if use_pathlib:
compressed = pathlib.Path(compressed)
tmpdir = pathlib.Path(tmpdir)

utils.extract_archive(compressed, tmpdir, remove_finished=True)
extracted_dir = utils.extract_archive(compressed, tmpdir, remove_finished=True)

assert not os.path.exists(compressed)
if use_pathlib:
assert isinstance(extracted_dir, pathlib.Path)
assert isinstance(compressed, pathlib.Path)
else:
assert isinstance(extracted_dir, str)
assert isinstance(compressed, str)

@pytest.mark.parametrize("extension", [".gz", ".xz"])
@pytest.mark.parametrize("remove_finished", [True, False])
@@ -167,7 +184,8 @@ def test_extract_archive_defer_to_decompress(self, extension, remove_finished, m

mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)

def test_extract_zip(self, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_extract_zip(self, tmpdir, use_pathlib):
def create_archive(root, content="this is the content"):
file = os.path.join(root, "dst.txt")
archive = os.path.join(root, "archive.zip")
@@ -177,6 +195,8 @@ def create_archive(root, content="this is the content"):

return archive, file, content

if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
archive, file, content = create_archive(tmpdir)

utils.extract_archive(archive, tmpdir)
@@ -189,7 +209,8 @@ def create_archive(root, content="this is the content"):
@pytest.mark.parametrize(
"extension, mode", [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz")]
)
def test_extract_tar(self, extension, mode, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_extract_tar(self, extension, mode, tmpdir, use_pathlib):
def create_archive(root, extension, mode, content="this is the content"):
src = os.path.join(root, "src.txt")
dst = os.path.join(root, "dst.txt")
@@ -203,6 +224,8 @@ def create_archive(root, extension, mode, content="this is the content"):

return archive, dst, content

if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
archive, file, content = create_archive(tmpdir, extension, mode)

utils.extract_archive(archive, tmpdir)
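The test_decompress_remove_finished assertions above expect extract_archive to return the same path type that was passed in. Not part of the diff: a self-contained sketch of that behavior, assuming the utils.py changes below; all paths and contents are placeholders.

import pathlib
import tarfile
from torchvision.datasets import utils

# Build a tiny archive to extract (placeholder content, for illustration only).
root = pathlib.Path("scratch")
root.mkdir(exist_ok=True)
src = root / "dst.txt"
src.write_text("this is the content")
archive = root / "archive.tar"
with tarfile.open(archive, "w") as tar:
    tar.add(src, arcname="dst.txt")

# Path in -> Path out; a plain str input would keep returning a str.
extracted = utils.extract_archive(archive, root)
assert isinstance(extracted, pathlib.Path)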
69 changes: 46 additions & 23 deletions torchvision/datasets/utils.py
@@ -30,7 +30,7 @@

def _save_response_content(
content: Iterator[bytes],
destination: str,
destination: Union[str, pathlib.Path],
length: Optional[int] = None,
) -> None:
with open(destination, "wb") as fh, tqdm(total=length) as pbar:
@@ -43,12 +43,12 @@ def _save_response_content(
pbar.update(len(chunk))


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
def _urlretrieve(url: str, filename: Union[str, pathlib.Path], chunk_size: int = 1024 * 32) -> None:
with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
_save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)


def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
def calculate_md5(fpath: Union[str, pathlib.Path], chunk_size: int = 1024 * 1024) -> str:
# Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
# not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
# it, torchvision.datasets is unusable in these environments since we perform an MD5 check everywhere.
@@ -62,11 +62,11 @@ def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
return md5.hexdigest()


def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
def check_md5(fpath: Union[str, pathlib.Path], md5: str, **kwargs: Any) -> bool:
return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
def check_integrity(fpath: Union[str, pathlib.Path], md5: Optional[str] = None) -> bool:
if not os.path.isfile(fpath):
return False
if md5 is None:
@@ -106,7 +106,7 @@ def _get_google_drive_file_id(url: str) -> Optional[str]:
def download_url(
url: str,
root: Union[str, pathlib.Path],
filename: Optional[str] = None,
filename: Optional[Union[str, pathlib.Path]] = None,
md5: Optional[str] = None,
max_redirect_hops: int = 3,
) -> None:
@@ -159,7 +159,7 @@ def download_url(
raise RuntimeError("File not found or corrupted.")


def list_dir(root: str, prefix: bool = False) -> List[str]:
def list_dir(root: Union[str, pathlib.Path], prefix: bool = False) -> List[str]:
"""List all directories at a given root

Args:
@@ -174,7 +174,7 @@ def list_dir(root: str, prefix: bool = False) -> List[str]:
return directories


def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
def list_files(root: Union[str, pathlib.Path], suffix: str, prefix: bool = False) -> List[str]:
"""List all files ending with a suffix at a given root

Args:
@@ -208,7 +208,10 @@ def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple


def download_file_from_google_drive(
file_id: str, root: Union[str, pathlib.Path], filename: Optional[str] = None, md5: Optional[str] = None
file_id: str,
root: Union[str, pathlib.Path],
filename: Optional[Union[str, pathlib.Path]] = None,
md5: Optional[str] = None,
):
"""Download a Google Drive file from and place it in root.

@@ -278,7 +281,9 @@ def download_file_from_google_drive(
)


def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
def _extract_tar(
from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
) -> None:
with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
tar.extractall(to_path)

@@ -289,14 +294,16 @@ def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> No
}


def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
def _extract_zip(
from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
) -> None:
with zipfile.ZipFile(
from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
) as zip:
zip.extractall(to_path)


_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
_ARCHIVE_EXTRACTORS: Dict[str, Callable[[Union[str, pathlib.Path], Union[str, pathlib.Path], Optional[str]], None]] = {
".tar": _extract_tar,
".zip": _extract_zip,
}
@@ -312,7 +319,7 @@ def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> No
}


def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
def _detect_file_type(file: Union[str, pathlib.Path]) -> Tuple[str, Optional[str], Optional[str]]:
"""Detect the archive type and/or compression of a file.

Args:
@@ -355,7 +362,11 @@ def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")


def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
def _decompress(
from_path: Union[str, pathlib.Path],
to_path: Optional[Union[str, pathlib.Path]] = None,
remove_finished: bool = False,
) -> pathlib.Path:
r"""Decompress a file.

The compression is automatically detected from the file name.
@@ -373,7 +384,7 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished:
raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")

if to_path is None:
to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")
to_path = pathlib.Path(os.fspath(from_path).replace(suffix, archive_type if archive_type is not None else ""))

# We don't need to check for a missing key here, since this was already done in _detect_file_type()
compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]
@@ -384,10 +395,14 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished:
if remove_finished:
os.remove(from_path)

return to_path
return pathlib.Path(to_path)


def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
def extract_archive(
from_path: Union[str, pathlib.Path],
to_path: Optional[Union[str, pathlib.Path]] = None,
remove_finished: bool = False,
) -> Union[str, pathlib.Path]:
"""Extract an archive.

The archive type and a possible compression are automatically detected from the file name. If the file is compressed
@@ -402,16 +417,24 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finish
Returns:
(str or pathlib.Path): Path to the directory the file was extracted to (the type mirrors from_path).
"""

def path_or_str(ret_path: pathlib.Path) -> Union[str, pathlib.Path]:
if isinstance(from_path, str):
return os.fspath(ret_path)
else:
return ret_path

if to_path is None:
to_path = os.path.dirname(from_path)

suffix, archive_type, compression = _detect_file_type(from_path)
if not archive_type:
return _decompress(
ret_path = _decompress(
from_path,
os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
remove_finished=remove_finished,
)
return path_or_str(ret_path)

# We don't need to check for a missing key here, since this was already done in _detect_file_type()
extractor = _ARCHIVE_EXTRACTORS[archive_type]
@@ -420,14 +443,14 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finish
if remove_finished:
os.remove(from_path)

return to_path
return path_or_str(pathlib.Path(to_path))


def download_and_extract_archive(
url: str,
download_root: str,
extract_root: Optional[str] = None,
filename: Optional[str] = None,
download_root: Union[str, pathlib.Path],
extract_root: Optional[Union[str, pathlib.Path]] = None,
filename: Optional[Union[str, pathlib.Path]] = None,
md5: Optional[str] = None,
remove_finished: bool = False,
) -> None:
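Not part of the diff: a sketch of the widened download_and_extract_archive signature. The URL, paths, and filename are placeholders, and the checksum check is skipped by passing md5=None.

import pathlib
from torchvision.datasets import utils

# Hypothetical call, for illustration only: both roots and the filename may now
# be pathlib.Path objects instead of str.
utils.download_and_extract_archive(
    "https://example.com/dataset.tar.gz",       # placeholder URL
    download_root=pathlib.Path("downloads"),
    extract_root=pathlib.Path("data"),
    filename=pathlib.Path("dataset.tar.gz"),
    md5=None,
    remove_finished=True,
)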
@@ -479,7 +502,7 @@ def verify_str_arg(
return value


def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
def _read_pfm(file_name: Union[str, pathlib.Path], slice_channels: int = 2) -> np.ndarray:
"""Read file in .pfm format. Might contain either 1 or 3 channels of data.

Args: