|
| 1 | +from torch.utils import data |
| 2 | + |
| 3 | + |
| 4 | +class BaseDataset(data.Dataset): |
| 5 | + """An abstract class representing a Dataset. |
| 6 | +
|
| 7 | + All other datasets should subclass it. All subclasses should override |
| 8 | + ``build_dataset`` which construct the dataset-like object from root. |
| 9 | +
|
| 10 | + A dataset-like object has both ``__len__`` and ``__getitem__`` implmented. |
| 11 | + Typical dataset-like objects include python list and numpy ndarray. |
| 12 | +
|
| 13 | + Args: |
| 14 | + root (object): Source of the dataset. |
| 15 | + sample_transform (callable, optional): A function/transform that takes |
| 16 | + a sample and returns a transformed version. |
| 17 | + target_transform (callable, optional): A function/transform that takes |
| 18 | + a target and transform it. |
| 19 | +
|
| 20 | + Attributes: |
| 21 | + samples (dataset-like object): Dataset-like object for samples. |
| 22 | + targets (dataset-like object): Dataset-like object for targets. |
| 23 | +
|
| 24 | + """ |
| 25 | + |
| 26 | + _repr_indent = 4 |
| 27 | + |
| 28 | + def __init__(self, root=None, sample_transform=None, target_transform=None): |
| 29 | + self.root = root |
| 30 | + |
| 31 | + if self.root is not None: |
| 32 | + self.samples, self.targets = self.build_dataset() |
| 33 | + |
| 34 | + if len(samples) == len(targets): |
| 35 | + self.size = len(targets) |
| 36 | + else: |
| 37 | + raise ValueError("Samples and targets have different lengths.") |
| 38 | + |
| 39 | + self.sample_transform = sample_transform |
| 40 | + self.target_transform = target_transform |
| 41 | + |
| 42 | + def build_dataset(self): |
| 43 | + """ |
| 44 | + Returns: |
| 45 | + samples (ndarray): List of samples. |
| 46 | + labels (ndarray): List of labels. |
| 47 | +
|
| 48 | + """ |
| 49 | + raise NotImplementedError |
| 50 | + |
| 51 | + def __getitem__(self, index): |
| 52 | + sample = self.samples[index] |
| 53 | + target = self.targets[index] |
| 54 | + |
| 55 | + if self.sample_transform is not None: |
| 56 | + sample = self.sample_transform(sample) |
| 57 | + |
| 58 | + if self.target_transform is not None: |
| 59 | + target = transform_transform(target) |
| 60 | + |
| 61 | + return sample, target |
| 62 | + |
| 63 | + def __len__(self): |
| 64 | + return self.size |
| 65 | + |
| 66 | + def __repr__(self): |
| 67 | + head = "Dataset " + self.__class__.__name__ |
| 68 | + body = ["Number of datapoints: {}".format(self.__len__())] |
| 69 | + if self.root is not None: |
| 70 | + body.append("Root location: {}".format(self.root)) |
| 71 | + body += self.extra_repr().splitlines() |
| 72 | + if self.sample_transform is not None: |
| 73 | + body += self._format_transform_repr(self.sample_transform, |
| 74 | + "Sample transforms: ") |
| 75 | + if self.target_transform is not None: |
| 76 | + body += self._format_transform_repr(self.target_transform, |
| 77 | + "Target transforms: ") |
| 78 | + lines = [head] + [" " * self._repr_indent + line for line in body] |
| 79 | + return '\n'.join(lines) |
| 80 | + |
| 81 | + def _format_transform_repr(self, transform, head): |
| 82 | + lines = transform.__repr__().splitlines() |
| 83 | + return (["{}{}".format(head, lines[0])] + |
| 84 | + ["{}{}".format(" " * len(head), line) for line in lines[1:]]) |
| 85 | + |
| 86 | + def extra_repr(self): |
| 87 | + return "" |
| 88 | + |
| 89 | + |
| 90 | +class PathDataset(BaseDataset): |
| 91 | + """An abstract class representing a Dataset defined by a Path. |
| 92 | +
|
| 93 | + Args: |
| 94 | + root (object): Path to the dataset. |
| 95 | + sample_transform (callable, optional): A function/transform that takes |
| 96 | + a sample and returns a transformed version. |
| 97 | + target_transform (callable, optional): A function/transform that takes |
| 98 | + a target and transform it. |
| 99 | +
|
| 100 | + Attributes: |
| 101 | + samples (list): List of samples in the dataset. |
| 102 | + targets (list): List of targets in teh dataset. |
| 103 | +
|
| 104 | + """ |
| 105 | + |
| 106 | + def __init__(self, **kwargs): |
| 107 | + super(PathDataset, self).__init__(**kwargs) |
| 108 | + self.root = Path(self.root).resolve() |
0 commit comments