-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathiter_for.py
29 lines (27 loc) · 903 Bytes
/
iter_for.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def batch_iter(iter, batch_size):
    """Split a sequence into consecutive batches of ``batch_size`` items.

    Every batch holds exactly ``batch_size`` items except possibly the last,
    which holds the remainder. Batches are slices of the input, so they keep
    its type (a list gives lists, a string gives strings, ...).

    Args:
        iter: Sliceable sequence supporting ``len()``. (The name shadows the
            builtin ``iter`` but is kept for keyword-argument callers.)
        batch_size: Maximum number of items per batch; must be >= 1.

    Returns:
        A list of batches in input order; an empty input yields ``[]``
        rather than a single empty batch.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # range() rejects a zero step and silently yields nothing for a negative
    # one, so validate up front with a clear message.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    return [
        iter[start:start + batch_size]
        for start in range(0, len(iter), batch_size)
    ]
def batch_iter2(iter, num_batch):
    """Split a sequence into ``num_batch`` contiguous batches.

    When the input has at least ``num_batch`` items, the first
    ``num_batch - 1`` batches each hold ``len(iter) // num_batch`` items and
    the last batch holds everything that remains, so it may be larger than
    the others. When there are fewer items than requested batches, one batch
    per item is returned instead (so the result has fewer than ``num_batch``
    batches), and an empty input yields an empty list.

    Batches are slices of the input, so they keep its type (a list gives
    lists, a tuple gives tuples, a string gives strings, ...).

    Args:
        iter: Sliceable sequence supporting ``len()``. (The name shadows the
            builtin ``iter`` but is kept for keyword-argument callers.)
        num_batch: Requested number of batches; must be >= 1.

    Returns:
        A list of ``min(num_batch, len(iter))`` batches in input order.

    Raises:
        ValueError: If ``num_batch`` is less than 1.
    """
    if num_batch < 1:
        raise ValueError(f"num_batch must be >= 1, got {num_batch}")
    num_iter = len(iter)
    # Floor division gives 0 when there are fewer items than batches, which
    # would be an invalid range() step; use one item per batch instead.
    batch_size = max(num_iter // num_batch, 1)
    num_out = min(num_batch, num_iter)
    batches = [
        iter[i * batch_size:(i + 1) * batch_size]
        for i in range(num_out - 1)
    ]
    if num_out:
        # The last batch absorbs the remainder. Slicing (rather than
        # flattening into a new list) keeps it the same type as the others.
        batches.append(iter[(num_out - 1) * batch_size:])
    return batches