Skip to content

Commit 5a7b0e8

Browse files
committed
Refactorize yews.datasets
1 parent 278caba commit 5a7b0e8

File tree

5 files changed

+284
-146
lines changed

5 files changed

+284
-146
lines changed

yews/datasets/__init__.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1-
from .classification import ClassificationDataset, DatasetArray, DatasetFolder
1+
from .base import BaseDataset, PathDataset
2+
from .file import FileDataset, DatasetArray
3+
from .dir import DirDataset, DatasetFolder, DatasetArrayFolder
24

35
__all__ = (
4-
'ClassificationDataset',
6+
'BaseDataset',
7+
'PathDataset',
8+
'FileDataset',
9+
'DirDataset',
510
'DatasetArray',
611
'DatasetFolder',
12+
'DatasetArrayFolder',
713
)
814

yews/datasets/base.py

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from torch.utils import data
2+
3+
4+
class BaseDataset(data.Dataset):
5+
"""An abstract class representing a Dataset.
6+
7+
All other datasets should subclass it. All subclasses should override
8+
``build_dataset`` which construct the dataset-like object from root.
9+
10+
A dataset-like object has both ``__len__`` and ``__getitem__`` implmented.
11+
Typical dataset-like objects include python list and numpy ndarray.
12+
13+
Args:
14+
root (object): Source of the dataset.
15+
sample_transform (callable, optional): A function/transform that takes
16+
a sample and returns a transformed version.
17+
target_transform (callable, optional): A function/transform that takes
18+
a target and transform it.
19+
20+
Attributes:
21+
samples (dataset-like object): Dataset-like object for samples.
22+
targets (dataset-like object): Dataset-like object for targets.
23+
24+
"""
25+
26+
_repr_indent = 4
27+
28+
def __init__(self, root=None, sample_transform=None, target_transform=None):
29+
self.root = root
30+
31+
if self.root is not None:
32+
self.samples, self.targets = self.build_dataset()
33+
34+
if len(samples) == len(targets):
35+
self.size = len(targets)
36+
else:
37+
raise ValueError("Samples and targets have different lengths.")
38+
39+
self.sample_transform = sample_transform
40+
self.target_transform = target_transform
41+
42+
def build_dataset(self):
43+
"""
44+
Returns:
45+
samples (ndarray): List of samples.
46+
labels (ndarray): List of labels.
47+
48+
"""
49+
raise NotImplementedError
50+
51+
def __getitem__(self, index):
52+
sample = self.samples[index]
53+
target = self.targets[index]
54+
55+
if self.sample_transform is not None:
56+
sample = self.sample_transform(sample)
57+
58+
if self.target_transform is not None:
59+
target = transform_transform(target)
60+
61+
return sample, target
62+
63+
def __len__(self):
64+
return self.size
65+
66+
def __repr__(self):
67+
head = "Dataset " + self.__class__.__name__
68+
body = ["Number of datapoints: {}".format(self.__len__())]
69+
if self.root is not None:
70+
body.append("Root location: {}".format(self.root))
71+
body += self.extra_repr().splitlines()
72+
if self.sample_transform is not None:
73+
body += self._format_transform_repr(self.sample_transform,
74+
"Sample transforms: ")
75+
if self.target_transform is not None:
76+
body += self._format_transform_repr(self.target_transform,
77+
"Target transforms: ")
78+
lines = [head] + [" " * self._repr_indent + line for line in body]
79+
return '\n'.join(lines)
80+
81+
def _format_transform_repr(self, transform, head):
82+
lines = transform.__repr__().splitlines()
83+
return (["{}{}".format(head, lines[0])] +
84+
["{}{}".format(" " * len(head), line) for line in lines[1:]])
85+
86+
def extra_repr(self):
87+
return ""
88+
89+
90+
class PathDataset(BaseDataset):
91+
"""An abstract class representing a Dataset defined by a Path.
92+
93+
Args:
94+
root (object): Path to the dataset.
95+
sample_transform (callable, optional): A function/transform that takes
96+
a sample and returns a transformed version.
97+
target_transform (callable, optional): A function/transform that takes
98+
a target and transform it.
99+
100+
Attributes:
101+
samples (list): List of samples in the dataset.
102+
targets (list): List of targets in teh dataset.
103+
104+
"""
105+
106+
def __init__(self, **kwargs):
107+
super(PathDataset, self).__init__(**kwargs)
108+
self.root = Path(self.root).resolve()

yews/datasets/classification.py

-144
This file was deleted.

yews/datasets/dir.py

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from .base import PathDataset
2+
3+
4+
class DirDataset(PathDataset):
5+
"""An abstract class representing a Dataset in a directory.
6+
7+
Args:
8+
root (object): Directory of the dataset.
9+
sample_transform (callable, optional): A function/transform that takes
10+
a sample and returns a transformed version.
11+
target_transform (callable, optional): A function/transform that takes
12+
a target and transform it.
13+
14+
Attributes:
15+
samples (list): List of samples in the dataset.
16+
targets (list): List of targets in teh dataset.
17+
18+
"""
19+
20+
def __init__(self, **kwargs):
21+
super(DirDataset, self).__init__(**kwargs)
22+
if not self.root.is_dir():
23+
raise ValueError(f"{self.root} is not a directory.")
24+
25+
26+
class DatasetArrayFolder(DirDataset):
27+
"""A generic data loader for a folder of ``.npy`` files where samples are
28+
arranged in the following way: ::
29+
30+
root/samples.npy: each row is a sample
31+
root/targets.npy: each row is a label
32+
33+
where both samples and targets can be arrays.
34+
35+
Args:
36+
root (object): Path to the dataset.
37+
sample_transform (callable, optional): A function/transform that takes
38+
a sample and returns a transformed version.
39+
target_transform (callable, optional): A function/transform that takes
40+
a target and transform it.
41+
42+
Attributes:
43+
samples (list): List of samples in the dataset.
44+
targets (list): List of targets in teh dataset.
45+
46+
"""
47+
48+
def build_dataset(self):
49+
samples = np.load(self.root / 'samples.npy', mmap_mode='r')
50+
targets = np.load(self.root / 'targets.npy', mmap_mode='r')
51+
52+
return samples, targets
53+
54+
55+
class DatasetFolder(DirDataset):
56+
"""A generic data loader for a folder where samples are arranged in the
57+
following way: ::
58+
59+
root/.../class_x.xxx
60+
root/.../class_x.sdf3
61+
root/.../class_x.asd932
62+
63+
root/.../class_y.yyy
64+
root/.../class_y.as4h
65+
root/.../blass_y.jlk2
66+
67+
Args:
68+
root (path): Path to the dataset.
69+
loader (callable): Function that load one sample from a file.
70+
sample_transform (callable, optional): A function/transform that takes
71+
a sample and returns a transformed version.
72+
target_transform (callable, optional): A function/transform that takes
73+
a target and transform it.
74+
75+
Attributes:
76+
samples (list): List of samples in the dataset.
77+
targets (list): List of targets in teh dataset.
78+
79+
80+
"""
81+
82+
class FilesLoader(object):
83+
"""A dataset-like class for loading a list of files given a loader.
84+
85+
Args:
86+
files (list): List of file paths
87+
loader (callable): Function that load one file.
88+
"""
89+
90+
def __init__(self, files, loader):
91+
self.files = files
92+
self.loader = loader
93+
94+
def __getitem__(self, index):
95+
return self.loader(self.file_list[index])
96+
97+
def __len__(self):
98+
return len(file_list)
99+
100+
def __init__(self, loader, **kwargs):
101+
super(DatasetFolder, self).__init__(**kwargs)
102+
self.loader = loader
103+
104+
def make_dataset(self):
105+
files = [p for p in self.root.glob("**/*") if p.is_file()]
106+
labels = [p.name.split('.')[0] for p in files]
107+
samples = self.FilesLoader(files, self.loader)
108+
109+
return samples, labels
110+

0 commit comments

Comments
 (0)