-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcnn_mnist_setup_stratified.py
executable file
·124 lines (100 loc) · 4.38 KB
/
cnn_mnist_setup_stratified.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import

import argparse
import glob
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import tensorflow
import tflearn
from sklearn.cross_validation import StratifiedKFold
from sklearn.cross_validation import StratifiedShuffleSplit
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.estimator import regression
from tflearn.layers.normalization import local_response_normalization
def remove_adjacent(train, test, runs):
    """
    Remove runs adjacent to test-set runs from the training set.

    For every index ``t`` in *test*, the neighbouring indices ``t - 1`` and
    ``t + 1`` are dropped from *train*, so no training run directly borders
    a test run. The relative order of the remaining training indices is kept.
    Test indices themselves are not removed from *train*; the two index sets
    are presumably disjoint already (e.g. produced by a fold split).

    :param train: indices of training set (list or 1-D numpy array)
    :param test: indices of test set (list or 1-D numpy array)
    :param runs: list of model runs, indexed by the values in train/test
    :returns: train, test set (as lists of model runs)
    """
    # Build the set of forbidden neighbours once instead of re-scanning and
    # re-masking train for every test item. Filtering by membership also
    # accepts plain Python lists, which boolean-mask indexing did not.
    neighbours = set()
    for item in test:
        neighbours.add(item - 1)
        neighbours.add(item + 1)
    set_train = [runs[item] for item in train if item not in neighbours]
    set_test = [runs[item] for item in test]
    return set_train, set_test
def construct_np_arrays(datdir, runs):
    """
    Load the numpy files of the given runs and stack models and labels.

    For each model file in *runs*, the matching label file is loaded from
    *datdir* as ``training_labels_<suffix>``, where ``<suffix>`` is the part
    of the model path starting at the last ``<8 digits>_<3 digits>`` group
    (e.g. ``.../training_data_20150101_001_x.npy`` -> ``20150101_001_x.npy``).

    :param datdir: directory holding the label files (trailing separator optional)
    :param runs: list of model run file paths (e.g. of training set)
    :returns: model_data_np (X), labels_np (Y), each stacked along axis 0;
        for an empty *runs* these have shapes (0, 28, 28, 21) and (0, 2)
    :raises ValueError: if a run path contains no ``<8 digits>_<3 digits>`` id
    """
    # Greedy leading ".*" makes the group start at the last id occurrence.
    index_pattern = re.compile(r'.*([0-9]{8}_[0-9]{3}.*)')
    # Seed with empty int arrays so an empty *runs* still yields correctly
    # shaped results, and so dtype promotion matches stacking onto them.
    model_parts = [np.empty((0, 28, 28, 21), int)]
    label_parts = [np.empty((0, 2), int)]
    for item in runs:
        match = index_pattern.match(item)
        if match is None:
            raise ValueError('cannot extract run id from file name: %r' % (item,))
        # load model data
        model_parts.append(np.load(item))
        # load matching labels
        label_path = os.path.join(datdir, 'training_labels_' + match.group(1))
        label_parts.append(np.load(label_path))
    # Concatenate once: calling np.concatenate inside the loop re-copies the
    # accumulated array on every iteration (quadratic in the data size).
    model_data_np = np.concatenate(model_parts, axis=0)
    labels_np = np.concatenate(label_parts, axis=0)
    return model_data_np, labels_np
def _build_network():
    """
    Build a fresh CNN classifier in the current default TensorFlow graph.

    Two conv/max-pool stages, two tanh dense layers with dropout, and a
    2-way softmax output trained with Adam on categorical cross-entropy.

    :returns: tflearn.DNN wrapping the network
    """
    # Building convolutional network (e.g. mnist tutorial), 21 input channels
    network = input_data(shape=[None, 28, 28, 21], name='input')
    network = conv_2d(network, 32, 3, activation='relu', regularizer="L2")
    network = max_pool_2d(network, 2)
    network = conv_2d(network, 64, 3, activation='relu', regularizer="L2")
    network = max_pool_2d(network, 2)
    network = fully_connected(network, 128, activation='tanh')
    # tflearn's dropout argument is the keep probability, not the drop rate.
    network = dropout(network, 0.8)
    network = fully_connected(network, 256, activation='tanh')
    network = dropout(network, 0.8)
    network = fully_connected(network, 2, activation='softmax')
    network = regression(network, optimizer='adam', learning_rate=0.001,
                         loss='categorical_crossentropy', name='target')
    return tflearn.DNN(network, tensorboard_verbose=0)


def main(region, datdir='/home/silviar/Dokumente/Training_set/', n_folds=10,
         test_size=0.1, n_epoch=30):
    """
    Train the CNN with stratified shuffle-split cross-validation.

    All ``training_data_*<region>*`` files in *datdir* and their label files
    are loaded. The samples are split *n_folds* times into train/validation
    sets stratified on the first label column, and a fresh network is
    trained on each split.

    :param region: region number used to select the training data files
    :param datdir: directory containing the training data and label files
    :param n_folds: number of random stratified splits
    :param test_size: fraction of samples held out for validation per split
    :param n_epoch: training epochs per split
    :raises IOError: if no training data file matches *region*
    """
    # Data loading & preprocessing
    # NOTE(review): this glob matches any file name containing str(region)
    # anywhere after the prefix -- presumably including inside the 8-digit
    # date / 3-digit run id -- so e.g. region 1 may also pick up files of
    # other regions. Confirm the naming scheme and tighten the pattern.
    pattern = os.path.join(datdir, 'training_data_*' + str(region) + '*')
    model_files = sorted(glob.glob(pattern))
    if not model_files:
        raise IOError('no training data files match %s' % pattern)
    X, y = construct_np_arrays(datdir, model_files)
    print("constructed initial arrays")
    # Stratify on the first column of the two-column label array.
    y_list = [item[0] for item in y]
    print("prepared for stratification")
    # NOTE(review): sklearn.cross_validation was removed in scikit-learn 0.20;
    # newer versions need sklearn.model_selection.StratifiedShuffleSplit(
    # n_splits=..., test_size=...).split(X, y_list) instead.
    splits = StratifiedShuffleSplit(y_list, n_iter=n_folds, test_size=test_size)
    for loop, (train, test) in enumerate(splits, start=1):
        # A new graph per split so each fold builds and trains its own model.
        with tensorflow.Graph().as_default():
            print('Performing loop ' + str(loop))
            print('preparing data set')
            X_train, Y_train = X[train], y[train]
            testX, testY = X[test], y[test]
            print('Building network')
            model = _build_network()
            print('Starting training')
            run_id = 'cnn_mnist_' + str(loop) + '_stratified'
            model.fit({'input': X_train}, {'target': Y_train}, n_epoch=n_epoch,
                      validation_set=({'input': testX}, {'target': testY}),
                      snapshot_step=500, show_metric=True, run_id=run_id)
if __name__ == "__main__":
    # get Region; argparse validates it is an integer and prints a usage
    # error otherwise (instead of an uncaught ValueError traceback).
    p = argparse.ArgumentParser(
        description='Train the CNN with stratified shuffle-split '
                    'cross-validation for one region.')
    p.add_argument("region", type=int,
                   help="region number selecting the training_data files")
    args = p.parse_args()
    main(args.region)