Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added missing typing annotations in datasets/_stereo_matching #6846

Merged
merged 12 commits into from
Dec 21, 2022
32 changes: 18 additions & 14 deletions torchvision/datasets/_stereo_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class StereoMatchingDataset(ABC, VisionDataset):

_has_built_in_disparity_mask = False

def __init__(self, root: str, transforms: Optional[Callable] = None):
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
"""
Args:
root(str): Root directory of the dataset.
Expand Down Expand Up @@ -58,7 +58,11 @@ def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
img = img.convert("RGB")
return img

def _scan_pairs(self, paths_left_pattern: str, paths_right_pattern: Optional[str] = None):
def _scan_pairs(
self,
paths_left_pattern: str,
paths_right_pattern: Optional[str] = None,
) -> List[Tuple[str, Optional[str]]]:

left_paths = list(sorted(glob(paths_left_pattern)))

Expand Down Expand Up @@ -156,7 +160,7 @@ class CarlaStereo(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

def __init__(self, root: str, transforms: Optional[Callable] = None):
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

root = Path(root) / "carla-highres"
Expand Down Expand Up @@ -233,7 +237,7 @@ class Kitti2012Stereo(StereoMatchingDataset):

_has_built_in_disparity_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand Down Expand Up @@ -321,7 +325,7 @@ class Kitti2015Stereo(StereoMatchingDataset):

_has_built_in_disparity_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand Down Expand Up @@ -479,7 +483,7 @@ def __init__(
use_ambient_views: bool = False,
transforms: Optional[Callable] = None,
download: bool = False,
):
) -> None:
super().__init__(root, transforms)

verify_str_arg(split, "split", valid_values=("train", "test", "additional"))
Expand Down Expand Up @@ -569,7 +573,7 @@ def _read_disparity(self, file_path: str) -> Tuple:
valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities
return disparity_map, valid_mask

def _download_dataset(self, root: str):
def _download_dataset(self, root: str) -> None:
base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
# train and additional splits have 2 different calibration settings
root = Path(root) / "Middlebury2014"
Expand Down Expand Up @@ -670,7 +674,7 @@ def __init__(
self,
root: str,
transforms: Optional[Callable] = None,
):
) -> None:
super().__init__(root, transforms)

root = Path(root) / "CREStereo"
Expand Down Expand Up @@ -755,7 +759,7 @@ class FallingThingsStereo(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None):
def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

root = Path(root) / "FallingThings"
Expand Down Expand Up @@ -874,7 +878,7 @@ def __init__(
variant: str = "FlyingThings3D",
pass_name: str = "clean",
transforms: Optional[Callable] = None,
):
) -> None:
super().__init__(root, transforms)

root = Path(root) / "SceneFlow"
Expand Down Expand Up @@ -973,7 +977,7 @@ class SintelStereo(StereoMatchingDataset):

_has_built_in_disparity_mask = True

def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None):
def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))
Expand Down Expand Up @@ -1080,7 +1084,7 @@ class InStereo2k(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

root = Path(root) / "InStereo2k" / split
Expand All @@ -1095,7 +1099,7 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl
right_disparity_pattern = str(root / "*" / "right_disp.png")
self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)

def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
# unsqueeze disparity to (C, H, W)
disparity_map = disparity_map[None, :, :] / 1024.0
Expand Down Expand Up @@ -1169,7 +1173,7 @@ class ETH3DStereo(StereoMatchingDataset):

_has_built_in_disparity_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand Down