Skip to content

Commit 2ba586d

Browse files
NicolasHugpmeier
andauthored
Document that datasets support pathlib.Path (#8321)
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
1 parent 0325175 commit 2ba586d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+216
-181
lines changed

torchvision/datasets/_optical_flow.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from .utils import _read_pfm, verify_str_arg
1414
from .vision import VisionDataset
1515

16-
1716
T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
1817
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]
1918

@@ -33,7 +32,7 @@ class FlowDataset(ABC, VisionDataset):
3332
# and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
3433
_has_builtin_flow_mask = False
3534

36-
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
35+
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
3736

3837
super().__init__(root=root)
3938
self.transforms = transforms
@@ -113,7 +112,7 @@ class Sintel(FlowDataset):
113112
...
114113
115114
Args:
116-
root (string): Root directory of the Sintel Dataset.
115+
root (str or ``pathlib.Path``): Root directory of the Sintel Dataset.
117116
split (string, optional): The dataset split, either "train" (default) or "test"
118117
pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
119118
details on the different passes.
@@ -125,7 +124,7 @@ class Sintel(FlowDataset):
125124

126125
def __init__(
127126
self,
128-
root: str,
127+
root: Union[str, Path],
129128
split: str = "train",
130129
pass_name: str = "clean",
131130
transforms: Optional[Callable] = None,
@@ -183,15 +182,15 @@ class KittiFlow(FlowDataset):
183182
flow_occ
184183
185184
Args:
186-
root (string): Root directory of the KittiFlow Dataset.
185+
root (str or ``pathlib.Path``): Root directory of the KittiFlow Dataset.
187186
split (string, optional): The dataset split, either "train" (default) or "test"
188187
transforms (callable, optional): A function/transform that takes in
189188
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
190189
"""
191190

192191
_has_builtin_flow_mask = True
193192

194-
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
193+
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
195194
super().__init__(root=root, transforms=transforms)
196195

197196
verify_str_arg(split, "split", valid_values=("train", "test"))
@@ -248,15 +247,15 @@ class FlyingChairs(FlowDataset):
248247
249248
250249
Args:
251-
root (string): Root directory of the FlyingChairs Dataset.
250+
root (str or ``pathlib.Path``): Root directory of the FlyingChairs Dataset.
252251
split (string, optional): The dataset split, either "train" (default) or "val"
253252
transforms (callable, optional): A function/transform that takes in
254253
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
255254
``valid_flow_mask`` is expected for consistency with other datasets which
256255
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
257256
"""
258257

259-
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
258+
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
260259
super().__init__(root=root, transforms=transforms)
261260

262261
verify_str_arg(split, "split", valid_values=("train", "val"))
@@ -316,7 +315,7 @@ class FlyingThings3D(FlowDataset):
316315
TRAIN
317316
318317
Args:
319-
root (string): Root directory of the intel FlyingThings3D Dataset.
318+
root (str or ``pathlib.Path``): Root directory of the intel FlyingThings3D Dataset.
320319
split (string, optional): The dataset split, either "train" (default) or "test"
321320
pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
322321
details on the different passes.
@@ -329,7 +328,7 @@ class FlyingThings3D(FlowDataset):
329328

330329
def __init__(
331330
self,
332-
root: str,
331+
root: Union[str, Path],
333332
split: str = "train",
334333
pass_name: str = "clean",
335334
camera: str = "left",
@@ -411,15 +410,15 @@ class HD1K(FlowDataset):
411410
image_2
412411
413412
Args:
414-
root (string): Root directory of the HD1K Dataset.
413+
root (str or ``pathlib.Path``): Root directory of the HD1K Dataset.
415414
split (string, optional): The dataset split, either "train" (default) or "test"
416415
transforms (callable, optional): A function/transform that takes in
417416
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
418417
"""
419418

420419
_has_builtin_flow_mask = True
421420

422-
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
421+
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
423422
super().__init__(root=root, transforms=transforms)
424423

425424
verify_str_arg(split, "split", valid_values=("train", "test"))

torchvision/datasets/_stereo_matching.py

+21-21
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset):
2727

2828
_has_built_in_disparity_mask = False
2929

30-
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
30+
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
3131
"""
3232
Args:
3333
root(str): Root directory of the dataset.
@@ -159,11 +159,11 @@ class CarlaStereo(StereoMatchingDataset):
159159
...
160160
161161
Args:
162-
root (string): Root directory where `carla-highres` is located.
162+
root (str or ``pathlib.Path``): Root directory where `carla-highres` is located.
163163
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
164164
"""
165165

166-
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
166+
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
167167
super().__init__(root, transforms)
168168

169169
root = Path(root) / "carla-highres"
@@ -233,14 +233,14 @@ class Kitti2012Stereo(StereoMatchingDataset):
233233
calib
234234
235235
Args:
236-
root (string): Root directory where `Kitti2012` is located.
236+
root (str or ``pathlib.Path``): Root directory where `Kitti2012` is located.
237237
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
238238
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
239239
"""
240240

241241
_has_built_in_disparity_mask = True
242242

243-
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
243+
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
244244
super().__init__(root, transforms)
245245

246246
verify_str_arg(split, "split", valid_values=("train", "test"))
@@ -321,14 +321,14 @@ class Kitti2015Stereo(StereoMatchingDataset):
321321
calib
322322
323323
Args:
324-
root (string): Root directory where `Kitti2015` is located.
324+
root (str or ``pathlib.Path``): Root directory where `Kitti2015` is located.
325325
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
326326
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
327327
"""
328328

329329
_has_built_in_disparity_mask = True
330330

331-
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
331+
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
332332
super().__init__(root, transforms)
333333

334334
verify_str_arg(split, "split", valid_values=("train", "test"))
@@ -420,7 +420,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
420420
...
421421
422422
Args:
423-
root (string): Root directory of the Middleburry 2014 Dataset.
423+
root (str or ``pathlib.Path``): Root directory of the Middleburry 2014 Dataset.
424424
split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional"
425425
use_ambient_views (boolean, optional): Whether to use different expose or lightning views when possible.
426426
The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``.
@@ -480,7 +480,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
480480

481481
def __init__(
482482
self,
483-
root: str,
483+
root: Union[str, Path],
484484
split: str = "train",
485485
calibration: Optional[str] = "perfect",
486486
use_ambient_views: bool = False,
@@ -576,7 +576,7 @@ def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.n
576576
valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities
577577
return disparity_map, valid_mask
578578

579-
def _download_dataset(self, root: str) -> None:
579+
def _download_dataset(self, root: Union[str, Path]) -> None:
580580
base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
581581
# train and additional splits have 2 different calibration settings
582582
root = Path(root) / "Middlebury2014"
@@ -675,7 +675,7 @@ class CREStereo(StereoMatchingDataset):
675675

676676
def __init__(
677677
self,
678-
root: str,
678+
root: Union[str, Path],
679679
transforms: Optional[Callable] = None,
680680
) -> None:
681681
super().__init__(root, transforms)
@@ -757,12 +757,12 @@ class FallingThingsStereo(StereoMatchingDataset):
757757
...
758758
759759
Args:
760-
root (string): Root directory where FallingThings is located.
760+
root (str or ``pathlib.Path``): Root directory where FallingThings is located.
761761
variant (string): Which variant to use. Either "single", "mixed", or "both".
762762
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
763763
"""
764764

765-
def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None) -> None:
765+
def __init__(self, root: Union[str, Path], variant: str = "single", transforms: Optional[Callable] = None) -> None:
766766
super().__init__(root, transforms)
767767

768768
root = Path(root) / "FallingThings"
@@ -868,7 +868,7 @@ class SceneFlowStereo(StereoMatchingDataset):
868868
...
869869
870870
Args:
871-
root (string): Root directory where SceneFlow is located.
871+
root (str or ``pathlib.Path``): Root directory where SceneFlow is located.
872872
variant (string): Which dataset variant to user, "FlyingThings3D" (default), "Monkaa" or "Driving".
873873
pass_name (string): Which pass to use, "clean" (default), "final" or "both".
874874
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
@@ -877,7 +877,7 @@ class SceneFlowStereo(StereoMatchingDataset):
877877

878878
def __init__(
879879
self,
880-
root: str,
880+
root: Union[str, Path],
881881
variant: str = "FlyingThings3D",
882882
pass_name: str = "clean",
883883
transforms: Optional[Callable] = None,
@@ -973,14 +973,14 @@ class SintelStereo(StereoMatchingDataset):
973973
...
974974
975975
Args:
976-
root (string): Root directory where Sintel Stereo is located.
976+
root (str or ``pathlib.Path``): Root directory where Sintel Stereo is located.
977977
pass_name (string): The name of the pass to use, either "final", "clean" or "both".
978978
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
979979
"""
980980

981981
_has_built_in_disparity_mask = True
982982

983-
def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
983+
def __init__(self, root: Union[str, Path], pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
984984
super().__init__(root, transforms)
985985

986986
verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))
@@ -1082,12 +1082,12 @@ class InStereo2k(StereoMatchingDataset):
10821082
...
10831083
10841084
Args:
1085-
root (string): Root directory where InStereo2k is located.
1085+
root (str or ``pathlib.Path``): Root directory where InStereo2k is located.
10861086
split (string): Either "train" or "test".
10871087
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
10881088
"""
10891089

1090-
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
1090+
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
10911091
super().__init__(root, transforms)
10921092

10931093
root = Path(root) / "InStereo2k" / split
@@ -1169,14 +1169,14 @@ class ETH3DStereo(StereoMatchingDataset):
11691169
...
11701170
11711171
Args:
1172-
root (string): Root directory of the ETH3D Dataset.
1172+
root (str or ``pathlib.Path``): Root directory of the ETH3D Dataset.
11731173
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
11741174
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
11751175
"""
11761176

11771177
_has_built_in_disparity_mask = True
11781178

1179-
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
1179+
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
11801180
super().__init__(root, transforms)
11811181

11821182
verify_str_arg(split, "split", valid_values=("train", "test"))

torchvision/datasets/caltech.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import os.path
3+
from pathlib import Path
34
from typing import Any, Callable, List, Optional, Tuple, Union
45

56
from PIL import Image
@@ -16,7 +17,7 @@ class Caltech101(VisionDataset):
1617
This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
1718
1819
Args:
19-
root (string): Root directory of dataset where directory
20+
root (str or ``pathlib.Path``): Root directory of dataset where directory
2021
``caltech101`` exists or will be saved to if download is set to True.
2122
target_type (string or list, optional): Type of target to use, ``category`` or
2223
``annotation``. Can also be a list to output a tuple with all specified
@@ -38,7 +39,7 @@ class Caltech101(VisionDataset):
3839

3940
def __init__(
4041
self,
41-
root: str,
42+
root: Union[str, Path],
4243
target_type: Union[List[str], str] = "category",
4344
transform: Optional[Callable] = None,
4445
target_transform: Optional[Callable] = None,
@@ -153,7 +154,7 @@ class Caltech256(VisionDataset):
153154
"""`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.
154155
155156
Args:
156-
root (string): Root directory of dataset where directory
157+
root (str or ``pathlib.Path``): Root directory of dataset where directory
157158
``caltech256`` exists or will be saved to if download is set to True.
158159
transform (callable, optional): A function/transform that takes in a PIL image
159160
and returns a transformed version. E.g, ``transforms.RandomCrop``

torchvision/datasets/celeba.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import csv
22
import os
33
from collections import namedtuple
4+
from pathlib import Path
45
from typing import Any, Callable, List, Optional, Tuple, Union
56

67
import PIL
@@ -16,7 +17,7 @@ class CelebA(VisionDataset):
1617
"""`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
1718
1819
Args:
19-
root (string): Root directory where images are downloaded to.
20+
root (str or ``pathlib.Path``): Root directory where images are downloaded to.
2021
split (string): One of {'train', 'valid', 'test', 'all'}.
2122
Accordingly dataset is selected.
2223
target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
@@ -63,7 +64,7 @@ class CelebA(VisionDataset):
6364

6465
def __init__(
6566
self,
66-
root: str,
67+
root: Union[str, Path],
6768
split: str = "train",
6869
target_type: Union[List[str], str] = "attr",
6970
transform: Optional[Callable] = None,

torchvision/datasets/cifar.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os.path
22
import pickle
3-
from typing import Any, Callable, Optional, Tuple
3+
from pathlib import Path
4+
from typing import Any, Callable, Optional, Tuple, Union
45

56
import numpy as np
67
from PIL import Image
@@ -13,7 +14,7 @@ class CIFAR10(VisionDataset):
1314
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
1415
1516
Args:
16-
root (string): Root directory of dataset where directory
17+
root (str or ``pathlib.Path``): Root directory of dataset where directory
1718
``cifar-10-batches-py`` exists or will be saved to if download is set to True.
1819
train (bool, optional): If True, creates dataset from training set, otherwise
1920
creates from test set.
@@ -50,7 +51,7 @@ class CIFAR10(VisionDataset):
5051

5152
def __init__(
5253
self,
53-
root: str,
54+
root: Union[str, Path],
5455
train: bool = True,
5556
transform: Optional[Callable] = None,
5657
target_transform: Optional[Callable] = None,

torchvision/datasets/cityscapes.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import os
33
from collections import namedtuple
4+
from pathlib import Path
45
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
56

67
from PIL import Image
@@ -13,7 +14,7 @@ class Cityscapes(VisionDataset):
1314
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
1415
1516
Args:
16-
root (string): Root directory of dataset where directory ``leftImg8bit``
17+
root (str or ``pathlib.Path``): Root directory of dataset where directory ``leftImg8bit``
1718
and ``gtFine`` or ``gtCoarse`` are located.
1819
split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
1920
otherwise ``train``, ``train_extra`` or ``val``
@@ -103,7 +104,7 @@ class Cityscapes(VisionDataset):
103104

104105
def __init__(
105106
self,
106-
root: str,
107+
root: Union[str, Path],
107108
split: str = "train",
108109
mode: str = "fine",
109110
target_type: Union[List[str], str] = "instance",

0 commit comments

Comments
 (0)