-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathdatasets.py
57 lines (49 loc) · 1.96 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import numpy as np
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms as T
from torchvision import datasets as dset
# MNIST samples only need converting to tensors, for both splits.
MNIST_TRN_TRANSFORM = T.Compose([T.ToTensor()])
MNIST_TST_TRANSFORM = T.Compose([T.ToTensor()])

# CIFAR-10 samples are cropped to 28x28 before tensor conversion: a random
# crop for training (augmentation) and a deterministic center crop for test.
CIFAR10_TRN_TRANSFORM = T.Compose([T.RandomCrop(28), T.ToTensor()])
CIFAR10_TST_TRANSFORM = T.Compose([T.CenterCrop(28), T.ToTensor()])

# Human-readable CIFAR-10 class names, in this fixed order. Presumably indexed
# by integer label (0 = 'airplane' ... 9 = 'truck'); confirm against the
# dataset's own `classes` attribute.
CIFAR10_CLASSES = tuple(
    'airplane automobile bird cat deer dog frog horse ship truck'.split())
def _truncate_dataset(dataset, size):
    """Keep only the first ``size`` samples of a torchvision dataset, in place.

    Current torchvision releases store samples and labels in ``data`` and
    ``targets``. The legacy ``train_data``/``train_labels`` (and ``test_*``)
    names are getter-only deprecation properties on ``MNIST`` and are absent
    on ``CIFAR10``. Assigning to them therefore either raises
    ``AttributeError`` or silently leaves the dataset untouched. Old releases
    expose only the legacy names, so fall back to those when ``targets`` is
    missing.

    Args:
        dataset: A torchvision ``MNIST``/``CIFAR10``-style dataset.
        size: Number of leading samples to keep. ``None``, or a value larger
            than the split, keeps the whole split (plain slice semantics).

    Returns:
        The same ``dataset`` object, for convenience.
    """
    if hasattr(dataset, 'targets'):
        dataset.data = dataset.data[:size]
        dataset.targets = dataset.targets[:size]
    else:
        # Legacy torchvision: attribute names depend on which split this is.
        prefix = 'train' if dataset.train else 'test'
        for name in (prefix + '_data', prefix + '_labels'):
            setattr(dataset, name, getattr(dataset, name)[:size])
    return dataset


def get_mnist_dataset(trn_size=60000, tst_size=10000, root='./data'):
    """Load the MNIST train/test splits, downloading them if needed.

    Args:
        trn_size: Number of leading training samples to keep.
        tst_size: Number of leading test samples to keep.
        root: Directory where the dataset is stored/downloaded.

    Returns:
        ``(trainset, testset)`` with ``MNIST_TRN_TRANSFORM`` /
        ``MNIST_TST_TRANSFORM`` attached respectively.
    """
    trainset = dset.MNIST(root=root, train=True,
                          download=True, transform=MNIST_TRN_TRANSFORM)
    _truncate_dataset(trainset, trn_size)
    testset = dset.MNIST(root=root, train=False,
                         download=True, transform=MNIST_TST_TRANSFORM)
    _truncate_dataset(testset, tst_size)
    return trainset, testset


def get_cifar10_dataset(trn_size=60000, tst_size=10000, root='./data'):
    """Load the CIFAR-10 train/test splits, downloading them if needed.

    Args:
        trn_size: Number of leading training samples to keep. A value larger
            than the split keeps the whole split.
        tst_size: Number of leading test samples to keep.
        root: Directory where the dataset is stored/downloaded.

    Returns:
        ``(trainset, testset)`` with ``CIFAR10_TRN_TRANSFORM`` /
        ``CIFAR10_TST_TRANSFORM`` attached respectively.
    """
    trainset = dset.CIFAR10(root=root, train=True,
                            download=True, transform=CIFAR10_TRN_TRANSFORM)
    _truncate_dataset(trainset, trn_size)
    testset = dset.CIFAR10(root=root, train=False,
                           download=True, transform=CIFAR10_TST_TRANSFORM)
    _truncate_dataset(testset, tst_size)
    return trainset, testset
def get_data_loader(trainset, testset, batch_size=128):
    """Wrap a train/test dataset pair in DataLoaders of equal batch size.

    Only the training loader shuffles; the test loader iterates the dataset
    in its stored order.

    Returns:
        ``(trainloader, testloader)``.
    """
    split_specs = ((trainset, True), (testset, False))
    return tuple(
        DataLoader(split, batch_size=batch_size, shuffle=do_shuffle)
        for split, do_shuffle in split_specs
    )