Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added missing typing annotations in datasets/_stereo_matching #6846

Merged
merged 12 commits into from
Dec 21, 2022
79 changes: 43 additions & 36 deletions torchvision/datasets/_stereo_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from .utils import _read_pfm, download_and_extract_archive, verify_str_arg
from .vision import VisionDataset

T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], np.ndarray]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice way to keep things simple

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. I wonder if we can make the name more expressive though. Not blocking, but T1 and T2 are rather opaque.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, but I tried to come up with a meaningful name and honestly couldn't find one 😅

T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]

__all__ = ()

_read_pfm_file = functools.partial(_read_pfm, slice_channels=1)
Expand All @@ -24,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset):

_has_built_in_disparity_mask = False

def __init__(self, root: str, transforms: Optional[Callable] = None):
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
"""
Args:
root(str): Root directory of the dataset.
Expand Down Expand Up @@ -58,7 +61,11 @@ def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
img = img.convert("RGB")
return img

def _scan_pairs(self, paths_left_pattern: str, paths_right_pattern: Optional[str] = None):
def _scan_pairs(
self,
paths_left_pattern: str,
paths_right_pattern: Optional[str] = None,
) -> List[Tuple[str, Optional[str]]]:

left_paths = list(sorted(glob(paths_left_pattern)))

Expand All @@ -85,11 +92,11 @@ def _scan_pairs(self, paths_left_pattern: str, paths_right_pattern: Optional[str
return paths

@abstractmethod
def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
# function that returns a disparity map and an occlusion map
pass

def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index.

Args:
Expand Down Expand Up @@ -120,7 +127,7 @@ def __getitem__(self, index: int) -> Tuple:
) = self.transforms(imgs, dsp_maps, valid_masks)

if self._has_built_in_disparity_mask or valid_masks[0] is not None:
return imgs[0], imgs[1], dsp_maps[0], valid_masks[0]
return imgs[0], imgs[1], dsp_maps[0], valid_masks[0] # type: ignore[return-value]
else:
return imgs[0], imgs[1], dsp_maps[0]

Expand Down Expand Up @@ -156,7 +163,7 @@ class CarlaStereo(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

def __init__(self, root: str, transforms: Optional[Callable] = None):
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

root = Path(root) / "carla-highres"
Expand All @@ -171,13 +178,13 @@ def __init__(self, root: str, transforms: Optional[Callable] = None):
disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
self._disparities = disparities

def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
disparity_map = _read_pfm_file(file_path)
disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
valid_mask = None
return disparity_map, valid_mask

def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index.

Args:
Expand Down Expand Up @@ -233,7 +240,7 @@ class Kitti2012Stereo(StereoMatchingDataset):

_has_built_in_disparity_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand All @@ -250,7 +257,7 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl
else:
self._disparities = list((None, None) for _ in self._images)

def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
# test split has no disparity maps
if file_path is None:
return None, None
Expand All @@ -261,7 +268,7 @@ def _read_disparity(self, file_path: str) -> Tuple:
valid_mask = None
return disparity_map, valid_mask

def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index.

Args:
Expand Down Expand Up @@ -321,7 +328,7 @@ class Kitti2015Stereo(StereoMatchingDataset):

_has_built_in_disparity_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand All @@ -338,7 +345,7 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl
else:
self._disparities = list((None, None) for _ in self._images)

def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
# test split has no disparity maps
if file_path is None:
return None, None
Expand All @@ -349,7 +356,7 @@ def _read_disparity(self, file_path: str) -> Tuple:
valid_mask = None
return disparity_map, valid_mask

def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index.

Args:
Expand Down Expand Up @@ -479,7 +486,7 @@ def __init__(
use_ambient_views: bool = False,
transforms: Optional[Callable] = None,
download: bool = False,
):
) -> None:
super().__init__(root, transforms)

verify_str_arg(split, "split", valid_values=("train", "test", "additional"))
Expand Down Expand Up @@ -558,7 +565,7 @@ def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
file_path = random.choice(ambient_file_paths) # type: ignore
return super()._read_img(file_path)

def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
# test split has not disparity maps
if file_path is None:
return None, None
Expand All @@ -569,7 +576,7 @@ def _read_disparity(self, file_path: str) -> Tuple:
valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities
return disparity_map, valid_mask

def _download_dataset(self, root: str):
def _download_dataset(self, root: str) -> None:
base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
# train and additional splits have 2 different calibration settings
root = Path(root) / "Middlebury2014"
Expand Down Expand Up @@ -608,7 +615,7 @@ def _download_dataset(self, root: str):
# cleanup MiddEval3 directory
shutil.rmtree(str(root / "MiddEval3"))

def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index.

Args:
Expand Down Expand Up @@ -670,7 +677,7 @@ def __init__(
self,
root: str,
transforms: Optional[Callable] = None,
):
) -> None:
super().__init__(root, transforms)

root = Path(root) / "CREStereo"
Expand All @@ -688,14 +695,14 @@ def __init__(
disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
self._disparities += disparities

def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
# unsqueeze the disparity map into (C, H, W) format
disparity_map = disparity_map[None, :, :] / 32.0
valid_mask = None
return disparity_map, valid_mask

def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index.

Args:
Expand Down Expand Up @@ -755,7 +762,7 @@ class FallingThingsStereo(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None):
def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

root = Path(root) / "FallingThings"
Expand All @@ -782,7 +789,7 @@ def __init__(self, root: str, variant: str = "single", transforms: Optional[Call
right_disparity_pattern = str(root / s / split_prefix[s] / "*.right.depth.png")
self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)

def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
# (H, W) image
depth = np.asarray(Image.open(file_path))
# as per https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt
Expand All @@ -799,7 +806,7 @@ def _read_disparity(self, file_path: str) -> Tuple:
valid_mask = None
return disparity_map, valid_mask

def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index.

Args:
Expand Down Expand Up @@ -874,7 +881,7 @@ def __init__(
variant: str = "FlyingThings3D",
pass_name: str = "clean",
transforms: Optional[Callable] = None,
):
) -> None:
super().__init__(root, transforms)

root = Path(root) / "SceneFlow"
Expand Down Expand Up @@ -905,13 +912,13 @@ def __init__(
right_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "right" / "*.pfm")
self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)

def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
disparity_map = _read_pfm_file(file_path)
disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
valid_mask = None
return disparity_map, valid_mask

def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index.

Args:
Expand Down Expand Up @@ -973,7 +980,7 @@ class SintelStereo(StereoMatchingDataset):

_has_built_in_disparity_mask = True

def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None):
def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))
Expand Down Expand Up @@ -1014,7 +1021,7 @@ def _get_occlussion_mask_paths(self, file_path: str) -> Tuple[str, str]:

return occlusion_path, outofframe_path

def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
if file_path is None:
return None, None

Expand All @@ -1034,7 +1041,7 @@ def _read_disparity(self, file_path: str) -> Tuple:
valid_mask = np.logical_and(off_mask, valid_mask)
return disparity_map, valid_mask

def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index.

Args:
Expand Down Expand Up @@ -1080,7 +1087,7 @@ class InStereo2k(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

root = Path(root) / "InStereo2k" / split
Expand All @@ -1095,14 +1102,14 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl
right_disparity_pattern = str(root / "*" / "right_disp.png")
self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)

def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
# unsqueeze disparity to (C, H, W)
disparity_map = disparity_map[None, :, :] / 1024.0
valid_mask = None
return disparity_map, valid_mask

def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index.

Args:
Expand Down Expand Up @@ -1169,7 +1176,7 @@ class ETH3DStereo(StereoMatchingDataset):

_has_built_in_disparity_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand All @@ -1189,7 +1196,7 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl
disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm")
self._disparities = self._scan_pairs(disparity_pattern, None)

def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
# test split has no disparity maps
if file_path is None:
return None, None
Expand All @@ -1201,7 +1208,7 @@ def _read_disparity(self, file_path: str) -> Tuple:
valid_mask = np.asarray(valid_mask).astype(bool)
return disparity_map, valid_mask

def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index.

Args:
Expand Down