-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathbasic_model.py
143 lines (113 loc) · 4.35 KB
/
basic_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import sys
import json
from pathlib import Path
import numpy as np
import keras
from keras.datasets import cifar10
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
# Global static configuration shared by the data, model and training code.
dtype_mult = 255.0  # uint8 maximum; raw pixels are divided by this in get_preprocessed_dataset
num_classes = 10  # number of label classes; width of the one-hot targets and of the output layer
X_shape = (-1, 32, 32, 3)  # (batch, height, width, channels); only X_shape[1:] is used, as the input shape
epoch = 200  # total number of epochs run by train()
batch_size = 128  # mini-batch size passed to model.fit
def get_dataset():
    """Load the raw CIFAR-10 train/test splits via ``cifar10.load_data()``.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)`` exactly as produced by the
        loader, with no preprocessing applied.
    """
    out = sys.stdout
    out.write('Loading Dataset\n')
    out.flush()
    train_split, test_split = cifar10.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split
    return train_images, train_labels, test_images, test_labels
def get_preprocessed_dataset():
    """Load CIFAR-10 and prepare it for training.

    Image arrays are cast to float32 and divided by ``dtype_mult``. Label
    arrays are one-hot encoded into ``num_classes`` columns with
    ``keras.utils.to_categorical``.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)`` of preprocessed arrays.
    """
    raw_x_train, raw_y_train, raw_x_test, raw_y_test = get_dataset()
    sys.stdout.write('Preprocessing Dataset\n\n')
    sys.stdout.flush()

    def scale(images):
        # float32 cast first so the division is not done in integer space.
        return images.astype('float32') / dtype_mult

    def one_hot(labels):
        return keras.utils.to_categorical(labels, num_classes)

    return (scale(raw_x_train), one_hot(raw_y_train),
            scale(raw_x_test), one_hot(raw_y_test))
def generate_optimizer():
    """Create the optimizer used by compile_model.

    Returns:
        A new ``keras.optimizers.Adam`` instance built with no arguments, so
        Keras' own default hyperparameters apply.
    """
    optimizer = keras.optimizers.Adam()
    return optimizer
def compile_model(model):
    """Compile ``model`` in place for multi-class classification.

    Uses categorical cross-entropy loss, a fresh optimizer from
    generate_optimizer(), and reports accuracy. Returns None; the model
    object itself is mutated by ``model.compile``.
    """
    settings = {
        'loss': 'categorical_crossentropy',
        'optimizer': generate_optimizer(),
        'metrics': ['accuracy'],
    }
    model.compile(**settings)
def _build_convnet():
    """Build (but do not compile) the CIFAR-10 convnet layer stack.

    Shape comments assume Conv2D's default 'valid' padding with stride 1, so
    every 3x3 convolution trims 2 pixels from height and width.
    """
    model = Sequential()
    # Conv1 32 32 (3) => 30 30 (32)
    model.add(Conv2D(32, (3, 3), input_shape=X_shape[1:]))
    model.add(Activation('relu'))
    # Conv2 30 30 (32) => 28 28 (32)
    model.add(Conv2D(32, (3, 3)))
    model.add(Activation('relu'))
    # Pool1 28 28 (32) => 14 14 (32)
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    # Conv3 14 14 (32) => 12 12 (64)
    model.add(Conv2D(64, (3, 3)))
    model.add(Activation('relu'))
    # Conv4 12 12 (64) => 10 10 (64)
    model.add(Conv2D(64, (3, 3)))
    model.add(Activation('relu'))
    # Pool2 10 10 (64) => 5 5 (64)
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    # Flatten 5 5 (64) => 1600
    model.add(Flatten())
    # Dense1 1600 => 256
    model.add(Dense(256))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    # Dense2 256 => num_classes
    model.add(Dense(num_classes))
    model.add(Activation('softmax'))
    return model


def generate_model():
    """Return a compiled convnet, restoring saved state when it exists.

    If ``./models/convnet_model.json`` exists, the architecture is rebuilt
    from it. When ``./models/convnet_weights.h5`` also exists, those weights
    are loaded into it.

    Otherwise a new model is built. Its architecture is written to
    ``./models/convnet_model.json``, and the ``./models`` directory is created
    first if missing.

    Either way, the returned model has been passed through compile_model().

    Returns:
        A compiled Keras model.
    """
    arch_path = './models/convnet_model.json'
    weights_path = './models/convnet_weights.h5'

    if Path(arch_path).is_file():
        sys.stdout.write('Loading existing model\n\n')
        sys.stdout.flush()
        # The file stores the architecture JSON text as a JSON string literal
        # (see json.dump below), so json.load yields the text that
        # model_from_json expects.
        with open(arch_path) as file:
            model = keras.models.model_from_json(json.load(file))
        # Likewise for the weights: restore them only if a checkpoint exists.
        if Path(weights_path).is_file():
            model.load_weights(weights_path)
        compile_model(model)
        return model

    sys.stdout.write('Loading new model\n\n')
    sys.stdout.flush()
    model = _build_convnet()
    # compile_model mutates the model in place rather than returning a new one.
    compile_model(model)
    # open(..., 'w') raises FileNotFoundError if the directory is missing.
    Path(arch_path).parent.mkdir(parents=True, exist_ok=True)
    with open(arch_path, 'w') as outfile:
        # to_json() already returns JSON text; json.dump stores it as a JSON
        # string (double encoding). The loading branch above reverses exactly
        # this, and existing saved files depend on the format — keep in sync.
        json.dump(model.to_json(), outfile)
    return model
def train(model, X_train, y_train, X_test, y_test):
    """Fit ``model`` for ``epoch`` epochs, checkpointing weights after each one.

    Each ``model.fit`` call runs a single epoch, so the weights can be saved
    to ``./models/convnet_weights.h5`` after every epoch. This is a safety
    measure so a crash loses at most the epoch in progress. The test split is
    used only as validation data.

    Args:
        model: A compiled Keras model.
        X_train, y_train: Training inputs and targets (one-hot in this
            script's main flow).
        X_test, y_test: Validation inputs and targets.

    Returns:
        The same ``model`` object, trained in place.
    """
    sys.stdout.write('Training model\n\n')
    sys.stdout.flush()
    # Ensure the checkpoint directory exists before the first save_weights.
    Path('./models').mkdir(parents=True, exist_ok=True)
    for epoch_count in range(1, epoch + 1):
        sys.stdout.write('Epoch count: ' + str(epoch_count) + '\n')
        sys.stdout.flush()
        # `epochs` is the Keras 2+ argument name. The Keras 1 `nb_epoch`
        # spelling was dropped in later releases, where it raises TypeError.
        model.fit(X_train, y_train, batch_size=batch_size,
                  epochs=1, validation_data=(X_test, y_test))
        sys.stdout.write('Epoch {} done, saving model to file\n\n'.format(epoch_count))
        sys.stdout.flush()
        # NOTE(review): Keras 3 requires save_weights paths to end in
        # '.weights.h5'; this filename assumes Keras 2 — confirm the version.
        model.save_weights('./models/convnet_weights.h5')
    return model
def _as_class_labels(values):
    """Return 1-D class labels: argmax over axis 1 for 2-D input, 1-D input as-is."""
    values = np.asarray(values)
    if values.ndim == 1:
        return values
    if values.ndim == 2:
        return values.argmax(axis=1)
    raise ValueError(
        'expected 1-D labels or 2-D (samples, classes) array, got ndim={}'.format(values.ndim))


def get_accuracy(pred, real):
    """Return the fraction of samples whose predicted class equals the true class.

    Args:
        pred: Per-class scores of shape (N, C), or 1-D integer labels (N,).
        real: Ground truth as one-hot rows (N, C), or 1-D integer labels (N,).

    Returns:
        Accuracy in [0, 1] as a float; NaN when there are no samples.

    Raises:
        ValueError: if the inputs have different sample counts or an
            unsupported number of dimensions.
    """
    pred_labels = _as_class_labels(pred)
    real_labels = _as_class_labels(real)
    if len(pred_labels) != len(real_labels):
        raise ValueError(
            'pred and real must have the same number of samples, got {} and {}'.format(
                len(pred_labels), len(real_labels)))
    if len(real_labels) == 0:
        # Accuracy is undefined without samples; return NaN explicitly instead
        # of relying on a 0/0 division that also emits a RuntimeWarning.
        return float('nan')
    return float(np.mean(pred_labels == real_labels))
def main():
    """Run the pipeline: load and preprocess data, build or restore the model, then train it."""
    banner = 'Welcome to CIFAR-10 Hello world of CONVNET!\n\n'
    sys.stdout.write(banner)
    sys.stdout.flush()
    # The split order (X_train, y_train, X_test, y_test) matches train()'s parameters.
    splits = get_preprocessed_dataset()
    train(generate_model(), *splits)
if __name__ == "__main__":
    # Start the pipeline only when this file is run directly, never on import.
    main()