# main.py
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import os
import pandas as pd
from torchvision.io import read_image
# ================================================================ #
# Loading a Dataset #
# ================================================================ #
train_data = datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor()
)
# ================================================================ #
# Visualize the Dataset #
# ================================================================ #
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    random_idx = torch.randint(len(train_data), size=(1,)).item()
    img, label = train_data[random_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis('off')
    plt.imshow(img.squeeze(), cmap='gray')
plt.show()
# ================================================================ #
# Custom Dataset #
# ================================================================ #
''' A custom Dataset class must implement three functions: __init__, __len__, and __getitem__ '''
class CustomDataset(Dataset):
    def __init__(self, annotation_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotation_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    # The __len__ function returns the number of samples in our dataset.
    def __len__(self):
        return len(self.img_labels)

    # The __getitem__ function loads and returns a sample from the dataset at the given index idx.
    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        sample = {"image": image, "label": label}
        return sample
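# Example usage of CustomDataset (a hedged sketch: 'labels.csv' and 'images' are
# hypothetical paths, and the CSV is assumed to hold file names in column 0 and
# labels in column 1, matching the indexing in __getitem__ above).
if os.path.exists('labels.csv'):
    custom_data = CustomDataset(annotation_file='labels.csv', img_dir='images')
    custom_loader = DataLoader(custom_data, batch_size=64, shuffle=True)
    first_sample = custom_data[0]  # {"image": C x H x W uint8 tensor, "label": value}
    print(f'Custom dataset size: {len(custom_data)}')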
# ================================================================ #
# DataLoaders #
# ================================================================ #
train_loader = DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True
)
test_loader = DataLoader(
    dataset=test_data,
    batch_size=64,
    shuffle=False
)
# ================================================================ #
# Iterate through DataLoaders #
# ================================================================ #
# Display image and label
train_features, train_labels = next(iter(train_loader))
print(f'Feature batch shape: {train_features.size()}')
print(f'Labels batch shape: {train_labels.size()}')
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap='gray')
plt.show()
print(f'Label: {label}')
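# A minimal sketch of walking the entire DataLoader (this loop only counts
# samples; a real training loop would feed each batch to a model instead).
total = 0
for batch_features, batch_labels in train_loader:
    total += batch_features.size(0)
print(f'Iterated over {total} training samples in batches of {train_loader.batch_size}')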