dataset.py
import collections
import nltk
import torch
import PIL
import PIL.Image  # ensure the Image submodule is loaded for PIL.Image.open below
import pytorch_lightning as pl
import numpy as np
import torchvision.transforms.functional as F
from pathlib import Path
from torchvision import transforms as T
from random import randint, choice
from torch.utils.data import DataLoader

import utils


def dl_collate_fn(batch):
    # Stack the image tensors into one batch tensor; keep the captions as a plain
    # Python list because each one is a variable-length string (or list of strings).
    return torch.stack([row[0] for row in batch]), [row[1] for row in batch]

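
# TextImageDataset pairs each image with a caption file that shares its filename stem.
# Illustrative layout (the names below are placeholders, not required by the code):
#   data/0001.jpg
#   data/0001.txt   # one caption per line; a random non-empty line is returned,
#                   # or all of them when all_captions=True
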
class TextImageDataset(torch.utils.data.Dataset):
    def __init__(self,
                 folder: str,
                 image_size=224,
                 resize_ratio=0.75,
                 shuffle=False,
                 all_captions=False,
                 is_eval=False):
        """Create a text-image dataset from a directory of images and caption files with matching names.

        Args:
            folder (str): Folder containing images and text files matched by their paths' respective "stem".
            image_size (int, optional): Side length of the output images. Defaults to 224.
            resize_ratio (float, optional): Minimum fraction of the image area kept by the random resized crop. Defaults to 0.75.
            shuffle (bool, optional): Whether to resample randomly (rather than sequentially) when a bad example is skipped. Defaults to False.
            all_captions (bool, optional): Whether to return all captions for an example instead of one chosen at random. Defaults to False.
            is_eval (bool, optional): Whether to use deterministic evaluation transforms instead of training augmentation. Defaults to False.
        """
        super().__init__()
        self.shuffle = shuffle
        self.all_captions = all_captions
        self.is_eval = is_eval
        path = Path(folder)

        text_files = [*path.glob('**/*.txt')]
        image_files = [
            *path.glob('**/*.png'), *path.glob('**/*.jpg'),
            *path.glob('**/*.jpeg'), *path.glob('**/*.bmp')
        ]

        # Pair files by stem and keep only examples that have both an image and a caption file.
        text_files = {text_file.stem: text_file for text_file in text_files}
        image_files = {image_file.stem: image_file for image_file in image_files}
        keys = (image_files.keys() & text_files.keys())

        self.keys = list(keys)
        self.text_files = {k: v for k, v in text_files.items() if k in keys}
        self.image_files = {k: v for k, v in image_files.items() if k in keys}
        self.resize_ratio = resize_ratio

        if self.is_eval:
            # Deterministic preprocessing for evaluation: bicubic resize + center crop.
            self.image_transform = T.Compose([
                T.Resize(image_size, interpolation=T.InterpolationMode.BICUBIC),
                T.CenterCrop(image_size),
                T.Lambda(self.fix_img),
                T.ToTensor(),
                # CLIP image normalization statistics.
                T.Normalize((0.48145466, 0.4578275, 0.40821073),
                            (0.26862954, 0.26130258, 0.27577711))
            ])
        else:
            # Training augmentation: random resized crop at a fixed 1:1 aspect ratio,
            # keeping at least `resize_ratio` of the original area.
            self.image_transform = T.Compose([
                T.Lambda(self.fix_img),
                T.RandomResizedCrop(image_size,
                                    scale=(self.resize_ratio, 1.),
                                    ratio=(1., 1.)),
                T.ToTensor(),
                T.Normalize((0.48145466, 0.4578275, 0.40821073),
                            (0.26862954, 0.26130258, 0.27577711))
            ])

    def __len__(self):
        return len(self.keys)

    def fix_img(self, img):
        # Convert grayscale/palette/RGBA images to RGB so they stack into a uniform batch.
        return img.convert('RGB') if img.mode != 'RGB' else img

    def random_sample(self):
        return self.__getitem__(randint(0, self.__len__() - 1))

    def sequential_sample(self, ind):
        if ind >= self.__len__() - 1:
            return self.__getitem__(0)
        return self.__getitem__(ind + 1)

    def skip_sample(self, ind):
        # Fall back to another example when the current one cannot be loaded.
        if self.shuffle:
            return self.random_sample()
        return self.sequential_sample(ind=ind)

    def __getitem__(self, ind):
        key = self.keys[ind]
        text_file = self.text_files[key]
        image_file = self.image_files[key]

        try:
            descriptions = text_file.read_text().split('\n')
        except UnicodeDecodeError:
            return self.skip_sample(ind)
        descriptions = list(filter(lambda t: len(t) > 0, descriptions))

        try:
            if self.all_captions:
                description = descriptions
            else:
                description = choice(descriptions)
        except IndexError as zero_captions_in_file_ex:
            print(f"An exception occurred trying to load file {text_file}.")
            print(f"Skipping index {ind}")
            return self.skip_sample(ind)

        try:
            image_tensor = self.image_transform(PIL.Image.open(image_file))
        except (PIL.UnidentifiedImageError, OSError) as corrupt_image_exceptions:
            print(f"An exception occurred trying to load file {image_file}.")
            print(f"Skipping index {ind}")
            return self.skip_sample(ind)

        # Success
        return image_tensor, description


class DataModule(pl.LightningDataModule):
    def __init__(self,
                 train_datadir,
                 dev_datadir,
                 batch_size=64,
                 nworkers=0,
                 all_captions=False,
                 preprocess=None):
        super().__init__()
        self.train_datadir = train_datadir
        self.dev_datadir = dev_datadir
        self.batch_size = batch_size
        self.nworkers = nworkers
        self.all_captions = all_captions
        self.preprocess = preprocess

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            # The training set is optional; the validation set is always built.
            if self.train_datadir:
                self.train = TextImageDataset(
                    folder=self.train_datadir, all_captions=self.all_captions)
            self.valid = TextImageDataset(
                folder=self.dev_datadir, is_eval=True, all_captions=self.all_captions)

    def train_dataloader(self):
        return DataLoader(
            self.train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.nworkers,
            collate_fn=dl_collate_fn,
            pin_memory=True)

    def val_dataloader(self):
        return DataLoader(
            self.valid,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.nworkers,
            collate_fn=dl_collate_fn,
            pin_memory=True)
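

if __name__ == '__main__':
    # Minimal usage sketch: the directory paths and batch size below are illustrative
    # placeholders; point them at folders laid out as described above TextImageDataset.
    dm = DataModule(train_datadir='data/train', dev_datadir='data/dev',
                    batch_size=4, nworkers=0)
    dm.setup(stage='fit')
    images, captions = next(iter(dm.train_dataloader()))
    print(images.shape)   # e.g. torch.Size([4, 3, 224, 224])
    print(captions[0])    # a caption string, or a list of strings when all_captions=True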