Migrating Autoencoder to TF 2.0 #6795

Closed · wants to merge 2 commits
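In summary, the PR swaps the graph-and-session autoencoder for a tf.keras implementation and rewrites the runner around a tf.data pipeline and tf.GradientTape. The core migration pattern is sketched below with toy stand-in names (not this PR's files): stop pushing batches through sess.run with a feed_dict, and instead record the forward pass on a tape and apply gradients eagerly.

# TF 1.x (removed): build a static graph, then push batches through a session:
#     cost, _ = sess.run((model.cost, model.optimizer), feed_dict={model.x: batch})
# TF 2.0 (added): run eagerly and differentiate with a gradient tape.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(784)])  # toy stand-in model
optimizer = tf.optimizers.Adam(learning_rate=0.01)
mse = tf.keras.losses.MeanSquaredError()

batch = tf.random.normal((128, 784))  # toy stand-in batch
with tf.GradientTape() as tape:
    recon = model(batch)
    loss = mse(batch, recon)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))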
52 changes: 30 additions & 22 deletions research/autoencoder/AutoencoderRunner.py
@@ -5,11 +5,11 @@
 import numpy as np
 import sklearn.preprocessing as prep
 import tensorflow as tf
-from tensorflow.examples.tutorials.mnist import input_data
+import tensorflow.keras.layers as layers

 from autoencoder_models.Autoencoder import Autoencoder

-mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
+mnist = tf.keras.datasets.mnist


 def standard_scale(X_train, X_test):
@@ -24,32 +24,40 @@ def get_random_block_from_data(data, batch_size):
     return data[start_index:(start_index + batch_size)]


-X_train, X_test = standard_scale(mnist.train.images, mnist.test.images)
+(X_train, _), (X_test, _) = mnist.load_data()
+X_train = tf.cast(np.reshape(X_train, (X_train.shape[0], X_train.shape[1] * X_train.shape[2])), tf.float64)
+X_test = tf.cast(np.reshape(X_test, (X_test.shape[0], X_test.shape[1] * X_test.shape[2])), tf.float64)

-n_samples = int(mnist.train.num_examples)
+X_train, X_test = standard_scale(X_train, X_test)
+
+train_data = tf.data.Dataset.from_tensor_slices(X_train).batch(128).shuffle(buffer_size=1024)
+test_data = tf.data.Dataset.from_tensor_slices(X_test).batch(128).shuffle(buffer_size=512)
+
+n_samples = int(len(X_train) + len(X_test))
 training_epochs = 20
 batch_size = 128
 display_step = 1

-autoencoder = Autoencoder(n_layers=[784, 200],
-                          transfer_function=tf.nn.softplus,
-                          optimizer=tf.train.AdamOptimizer(learning_rate=0.001))
+optimizer = tf.optimizers.Adam(learning_rate=0.01)
+mse_loss = tf.keras.losses.MeanSquaredError()
+loss_metric = tf.keras.metrics.Mean()
+
+autoencoder = Autoencoder([200, 394, 784])
+
+# Iterate over epochs.
+for epoch in range(10):
+    print(f'Epoch {epoch + 1}')

-for epoch in range(training_epochs):
-    avg_cost = 0.
-    total_batch = int(n_samples / batch_size)
-    # Loop over all batches
-    for i in range(total_batch):
-        batch_xs = get_random_block_from_data(X_train, batch_size)
+    # Iterate over the batches of the dataset.
+    for step, x_batch in enumerate(train_data):
+        with tf.GradientTape() as tape:
+            recon = autoencoder(x_batch)
+            loss = mse_loss(x_batch, recon)

-        # Fit training using batch data
-        cost = autoencoder.partial_fit(batch_xs)
-        # Compute average loss
-        avg_cost += cost / n_samples * batch_size
+        grads = tape.gradient(loss, autoencoder.trainable_variables)
+        optimizer.apply_gradients(zip(grads, autoencoder.trainable_variables))

-    # Display logs per epoch step
-    if epoch % display_step == 0:
-        print("Epoch:", '%d,' % (epoch + 1),
-              "Cost:", "{:.9f}".format(avg_cost))
+        loss_metric(loss)

-print("Total cost: " + str(autoencoder.calc_total_cost(X_test)))
+        if step % 100 == 0:
+            print(f'Step {step}: mean loss = {loss_metric.result()}')
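Review note: the new runner defines test_data but never consumes it, and the old final evaluation, print("Total cost: ...") on X_test, has no replacement. A minimal sketch of how the held-out loss could be reported after the loop, reusing the mse_loss, autoencoder, and test_data names defined above (not part of this PR):

# Sketch only: mean reconstruction loss over the held-out set.
test_metric = tf.keras.metrics.Mean()
for x_batch in test_data:
    test_metric(mse_loss(x_batch, autoencoder(x_batch)))
print(f'Test mean loss = {test_metric.result()}')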
140 changes: 52 additions & 88 deletions research/autoencoder/autoencoder_models/Autoencoder.py
@@ -1,91 +1,55 @@
 import numpy as np
 import tensorflow as tf


-class Autoencoder(object):
-
-    def __init__(self, n_layers, transfer_function=tf.nn.softplus, optimizer=tf.train.AdamOptimizer()):
-        self.n_layers = n_layers
-        self.transfer = transfer_function
-
-        network_weights = self._initialize_weights()
-        self.weights = network_weights
-
-        # model
-        self.x = tf.placeholder(tf.float32, [None, self.n_layers[0]])
-        self.hidden_encode = []
-        h = self.x
-        for layer in range(len(self.n_layers) - 1):
-            h = self.transfer(
-                tf.add(tf.matmul(h, self.weights['encode'][layer]['w']),
-                       self.weights['encode'][layer]['b']))
-            self.hidden_encode.append(h)
-
-        self.hidden_recon = []
-        for layer in range(len(self.n_layers) - 1):
-            h = self.transfer(
-                tf.add(tf.matmul(h, self.weights['recon'][layer]['w']),
-                       self.weights['recon'][layer]['b']))
-            self.hidden_recon.append(h)
-        self.reconstruction = self.hidden_recon[-1]
-
-        # cost
-        self.cost = 0.5 * tf.reduce_sum(tf.pow(tf.subtract(self.reconstruction, self.x), 2.0))
-        self.optimizer = optimizer.minimize(self.cost)
-
-        init = tf.global_variables_initializer()
-        self.sess = tf.Session()
-        self.sess.run(init)
-
-    def _initialize_weights(self):
-        all_weights = dict()
-        initializer = tf.contrib.layers.xavier_initializer()
-        # Encoding network weights
-        encoder_weights = []
-        for layer in range(len(self.n_layers) - 1):
-            w = tf.Variable(
-                initializer((self.n_layers[layer], self.n_layers[layer + 1]),
-                            dtype=tf.float32))
-            b = tf.Variable(
-                tf.zeros([self.n_layers[layer + 1]], dtype=tf.float32))
-            encoder_weights.append({'w': w, 'b': b})
-        # Recon network weights
-        recon_weights = []
-        for layer in range(len(self.n_layers) - 1, 0, -1):
-            w = tf.Variable(
-                initializer((self.n_layers[layer], self.n_layers[layer - 1]),
-                            dtype=tf.float32))
-            b = tf.Variable(
-                tf.zeros([self.n_layers[layer - 1]], dtype=tf.float32))
-            recon_weights.append({'w': w, 'b': b})
-        all_weights['encode'] = encoder_weights
-        all_weights['recon'] = recon_weights
-        return all_weights
-
-    def partial_fit(self, X):
-        cost, opt = self.sess.run((self.cost, self.optimizer), feed_dict={self.x: X})
-        return cost
-
-    def calc_total_cost(self, X):
-        return self.sess.run(self.cost, feed_dict={self.x: X})
-
-    def transform(self, X):
-        return self.sess.run(self.hidden_encode[-1], feed_dict={self.x: X})
-
-    def generate(self, hidden=None):
-        if hidden is None:
-            hidden = np.random.normal(size=self.weights['encode'][-1]['b'])
-        return self.sess.run(self.reconstruction, feed_dict={self.hidden_encode[-1]: hidden})
-
-    def reconstruct(self, X):
-        return self.sess.run(self.reconstruction, feed_dict={self.x: X})
-
-    def getWeights(self):
-        raise NotImplementedError
-        return self.sess.run(self.weights)
-
-    def getBiases(self):
-        raise NotImplementedError
-        return self.sess.run(self.weights)
+class Encoder(tf.keras.layers.Layer):
+    '''Encodes a digit from the MNIST dataset'''
+
+    def __init__(self,
+                 n_dims,
+                 name='encoder',
+                 **kwargs):
+        super(Encoder, self).__init__(name=name, **kwargs)
+        self.n_dims = n_dims
+        self.n_layers = 1
+        self.encode_layer = tf.keras.layers.Dense(n_dims, activation='relu')
+
+    @tf.function
+    def call(self, inputs):
+        return self.encode_layer(inputs)
+
+
+class Decoder(tf.keras.layers.Layer):
+    '''Decodes a digit from the MNIST dataset'''
+
+    def __init__(self,
+                 n_dims,
+                 name='decoder',
+                 **kwargs):
+        super(Decoder, self).__init__(name=name, **kwargs)
+        self.n_dims = n_dims
+        self.n_layers = len(n_dims)
+        self.decode_middle = tf.keras.layers.Dense(n_dims[0], activation='relu')
+        self.recon_layer = tf.keras.layers.Dense(n_dims[1], activation='sigmoid')
+
+    @tf.function
+    def call(self, inputs):
+        x = self.decode_middle(inputs)
+        return self.recon_layer(x)
+
+
+class Autoencoder(tf.keras.Model):
+    '''Vanilla Autoencoder for MNIST digits'''
+
+    def __init__(self,
+                 n_dims=[200, 392, 784],
+                 name='autoencoder',
+                 **kwargs):
+        super(Autoencoder, self).__init__(name=name, **kwargs)
+        self.n_dims = n_dims
+        self.encoder = Encoder(n_dims[0])
+        self.decoder = Decoder([n_dims[1], n_dims[2]])
+
+    @tf.function
+    def call(self, inputs):
+        x = self.encoder(inputs)
+        return self.decoder(x)
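A quick smoke test for the rewritten model, useful when reviewing: with the default n_dims of [200, 392, 784], a batch of flattened 28x28 digits should round-trip to its original width (the runner above passes [200, 394, 784], which also works but diverges from this default). The sketch below assumes the Autoencoder class above is in scope; the random batch is a stand-in for real MNIST data.

import numpy as np
import tensorflow as tf

ae = Autoencoder(n_dims=[200, 392, 784])
x = tf.convert_to_tensor(np.random.rand(32, 784), dtype=tf.float32)  # fake MNIST batch
recon = ae(x)
assert recon.shape == (32, 784)  # the sigmoid recon layer restores the input width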