
Commit 30397d9

Authored Dec 4, 2023
add Imagenette dataset (#8139)
1 parent 3feb502 commit 30397d9

File tree

5 files changed: +143 -0 lines changed

  docs/source/datasets.rst
  test/test_datasets.py
  torchvision/datasets/__init__.py
  torchvision/datasets/imagenette.py
  torchvision/tv_tensors/_dataset_wrapper.py

docs/source/datasets.rst (+1)

@@ -54,6 +54,7 @@ Image classification
 GTSRB
 INaturalist
 ImageNet
+Imagenette
 KMNIST
 LFWPeople
 LSUN

test/test_datasets.py (+35)

@@ -3377,6 +3377,41 @@ def test_bad_input(self):
         pass


+class ImagenetteTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.Imagenette
+    ADDITIONAL_CONFIGS = combinations_grid(split=["train", "val"], size=["full", "320px", "160px"])
+
+    _WNIDS = [
+        "n01440764",
+        "n02102040",
+        "n02979186",
+        "n03000684",
+        "n03028079",
+        "n03394916",
+        "n03417042",
+        "n03425413",
+        "n03445777",
+        "n03888257",
+    ]
+
+    def inject_fake_data(self, tmpdir, config):
+        archive_root = "imagenette2"
+        if config["size"] != "full":
+            archive_root += f"-{config['size'].replace('px', '')}"
+        image_root = pathlib.Path(tmpdir) / archive_root / config["split"]
+
+        num_images_per_class = 3
+        for wnid in self._WNIDS:
+            datasets_utils.create_image_folder(
+                root=image_root,
+                name=wnid,
+                file_name_fn=lambda idx: f"{wnid}_{idx}.JPEG",
+                num_examples=num_images_per_class,
+            )
+
+        return num_images_per_class * len(self._WNIDS)
+
+
 class TestDatasetWrapper:
     def test_unknown_type(self):
         unknown_object = object()
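
For context, `combinations_grid` from the shared test utilities expands the keyword lists in `ADDITIONAL_CONFIGS` into one config dict per combination, so the test case above runs once for every split/size pair. A rough standalone equivalent, not the actual helper, just a sketch of the behaviour it is assumed to have:

    import itertools

    def combinations_grid(**kwargs):
        # One dict per element of the Cartesian product of the value lists.
        return [dict(zip(kwargs, values)) for values in itertools.product(*kwargs.values())]

    configs = combinations_grid(split=["train", "val"], size=["full", "320px", "160px"])
    print(len(configs))  # 6 configs, e.g. {'split': 'train', 'size': '160px'}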

torchvision/datasets/__init__.py (+2)

@@ -30,6 +30,7 @@
 from .gtsrb import GTSRB
 from .hmdb51 import HMDB51
 from .imagenet import ImageNet
+from .imagenette import Imagenette
 from .inaturalist import INaturalist
 from .kinetics import Kinetics
 from .kitti import Kitti

@@ -128,6 +129,7 @@
     "InStereo2k",
     "ETH3DStereo",
     "wrap_dataset_for_transforms_v2",
+    "Imagenette",
 )

torchvision/datasets/imagenette.py (+104)

@@ -0,0 +1,104 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Tuple
+
+from PIL import Image
+
+from .folder import find_classes, make_dataset
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class Imagenette(VisionDataset):
+    """`Imagenette <https://github.com/fastai/imagenette#imagenette-1>`_ image classification dataset.
+
+    Args:
+        root (string): Root directory of the Imagenette dataset.
+        split (string, optional): The dataset split. Supports ``"train"`` (default) and ``"val"``.
+        size (string, optional): The image size. Supports ``"full"`` (default), ``"320px"``, and ``"160px"``.
+        download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
+            downloaded archives are not downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+            version, e.g. ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+
+    Attributes:
+        classes (list): List of the class name tuples.
+        class_to_idx (dict): Dict with items (class name, class index).
+        wnids (list): List of the WordNet IDs.
+        wnid_to_idx (dict): Dict with items (WordNet ID, class index).
+    """
+
+    _ARCHIVES = {
+        "full": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz", "fe2fc210e6bb7c5664d602c3cd71e612"),
+        "320px": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz", "3df6f0d01a2c9592104656642f5e78a3"),
+        "160px": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz", "e793b78cc4c9e9a4ccc0c1155377a412"),
+    }
+    _WNID_TO_CLASS = {
+        "n01440764": ("tench", "Tinca tinca"),
+        "n02102040": ("English springer", "English springer spaniel"),
+        "n02979186": ("cassette player",),
+        "n03000684": ("chain saw", "chainsaw"),
+        "n03028079": ("church", "church building"),
+        "n03394916": ("French horn", "horn"),
+        "n03417042": ("garbage truck", "dustcart"),
+        "n03425413": ("gas pump", "gasoline pump", "petrol pump", "island dispenser"),
+        "n03445777": ("golf ball",),
+        "n03888257": ("parachute", "chute"),
+    }
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        size: str = "full",
+        download=False,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self._split = verify_str_arg(split, "split", ["train", "val"])
+        self._size = verify_str_arg(size, "size", ["full", "320px", "160px"])
+
+        self._url, self._md5 = self._ARCHIVES[self._size]
+        self._size_root = Path(self.root) / Path(self._url).stem
+        self._image_root = str(self._size_root / self._split)
+
+        if download:
+            self._download()
+        elif not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it.")
+
+        self.wnids, self.wnid_to_idx = find_classes(self._image_root)
+        self.classes = [self._WNID_TO_CLASS[wnid] for wnid in self.wnids]
+        self.class_to_idx = {
+            class_name: idx for wnid, idx in self.wnid_to_idx.items() for class_name in self._WNID_TO_CLASS[wnid]
+        }
+        self._samples = make_dataset(self._image_root, self.wnid_to_idx, extensions=".jpeg")
+
+    def _check_exists(self) -> bool:
+        return self._size_root.exists()
+
+    def _download(self):
+        if self._check_exists():
+            raise RuntimeError(
+                f"The directory {self._size_root} already exists. "
+                f"If you want to re-download or re-extract the images, delete the directory."
+            )
+
+        download_and_extract_archive(self._url, self.root, md5=self._md5)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        path, label = self._samples[idx]
+        image = Image.open(path).convert("RGB")
+
+        if self.transform is not None:
+            image = self.transform(image)
+
+        if self.target_transform is not None:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def __len__(self) -> int:
+        return len(self._samples)
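
For reference, a minimal usage sketch of the Imagenette class added above; the root directory and the transform are illustrative and not part of this commit:

    from torchvision import datasets, transforms

    # Illustrative root directory; the 160px archive is downloaded and
    # extracted there on first use.
    dataset = datasets.Imagenette(
        root="data",
        split="train",
        size="160px",
        download=True,
        transform=transforms.ToTensor(),
    )

    image, label = dataset[0]  # tensor image, integer class index
    print(len(dataset), dataset.classes[label])

Because the extracted directory name encodes the size (imagenette2, imagenette2-320, imagenette2-160), the three variants can coexist under the same root.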

torchvision/tv_tensors/_dataset_wrapper.py (+1)

@@ -284,6 +284,7 @@ def classification_wrapper_factory(dataset, target_keys):
     datasets.GTSRB,
     datasets.DatasetFolder,
     datasets.ImageFolder,
+    datasets.Imagenette,
 ]:
     WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory)
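
With this registration, the new dataset also works with the v2 transforms wrapper exported from torchvision.datasets. A minimal sketch, assuming the archive was already downloaded under the given (illustrative) root:

    from torchvision import datasets
    from torchvision.datasets import wrap_dataset_for_transforms_v2

    # Illustrative root; reuses a previously downloaded/extracted archive.
    dataset = datasets.Imagenette("data", split="val", size="320px")
    dataset = wrap_dataset_for_transforms_v2(dataset)

    image, label = dataset[0]  # same (image, class index) pairs, now usable in transforms.v2 pipelines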

0 commit comments