-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
123 lines (89 loc) · 3.15 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import numpy as np
import chainer
from chainer.datasets import TupleDataset
def get_mnist(n_train=100, n_test=100, n_dim=1, with_label=True, classes=None):
    """
    Return class-balanced subsets of MNIST.

    :param n_train: nr of training examples per class
    :param n_test: nr of test examples per class
    :param n_dim: 1 or 3 (for convolutional input)
    :param with_label: whether or not to also provide labels
    :param classes: if not None, then it selects only those classes, e.g. [0, 1];
        selected classes are relabeled 0..len(classes)-1 in the returned data
    :return: (train_data, test_data) — TupleDatasets when with_label is True,
        plain example arrays otherwise
    """
    train_data, test_data = chainer.datasets.get_mnist(ndim=n_dim, withlabel=with_label)
    # Explicit None/empty check instead of `if not classes`: a non-empty numpy
    # array of classes would make `not classes` raise an ambiguous-truth-value
    # ValueError, while None (and the degenerate []) fall back to all digits.
    if classes is None or len(classes) == 0:
        classes = np.arange(10)

    def _select(data, labels, n):
        """Pick up to n examples of each requested class; relabel them 0..k-1."""
        per_class = [np.where(labels == c)[0][:n] for c in classes]
        idx = np.concatenate(per_class)
        # Build the new labels from the ACTUAL per-class counts (a class may
        # have fewer than n examples), so data and labels stay aligned.
        new_labels = np.concatenate(
            [i * np.ones(len(ix)) for i, ix in enumerate(per_class)]
        ).astype('int32')
        return data[idx], new_labels

    if with_label:
        tr_x, tr_y = _select(train_data._datasets[0], train_data._datasets[1], n_train)
        te_x, te_y = _select(test_data._datasets[0], test_data._datasets[1], n_test)
        return TupleDataset(tr_x, tr_y), TupleDataset(te_x, te_y)
    else:
        # Labels are not part of the returned data, but we still need them to
        # perform the per-class selection, so fetch a labeled copy.
        tmp_train, tmp_test = chainer.datasets.get_mnist(ndim=n_dim, withlabel=True)
        tr_x, _ = _select(train_data, tmp_train._datasets[1], n_train)
        te_x, _ = _select(test_data, tmp_test._datasets[1], n_test)
        return tr_x, te_x
# Custom iterator
class RandomIterator(object):
    """
    Generates random subsets of data.

    Each pass over the iterator draws a fresh random permutation and yields
    ``n_batches`` batches of exactly ``batch_size`` examples; a trailing
    partial batch is dropped.
    """
    def __init__(self, data, batch_size=1):
        """
        Args:
            data (TupleDataset or np.ndarray): dataset to draw batches from
            batch_size (int): number of examples per batch
        Returns:
            list of batches consisting of (input, output) pairs
        """
        self.data = data
        self.batch_size = batch_size
        # Integer division: any incomplete trailing batch is discarded.
        self.n_batches = len(self.data) // batch_size

    def __iter__(self):
        self.idx = -1
        # New random order each epoch, truncated to whole batches.
        self._order = np.random.permutation(len(self.data))[:(self.n_batches * self.batch_size)]
        return self

    def __next__(self):
        # Python 3 iterator protocol requires __next__ (the original only
        # defined Python-2-style next(), so `for batch in it` failed on Py3).
        self.idx += 1
        if self.idx == self.n_batches:
            raise StopIteration
        i = self.idx * self.batch_size
        # handles unlabeled and labeled data
        if isinstance(self.data, np.ndarray):
            return self.data[self._order[i:(i + self.batch_size)]]
        else:
            return list(self.data[self._order[i:(i + self.batch_size)]])

    # Backward-compatible alias for callers using the Python 2 protocol name.
    next = __next__