Add LSTMLayer and LSTMUnitLayer, with tests
1 parent cf5f369 · commit 51a68f0
Showing 5 changed files with 971 additions and 0 deletions.
@@ -0,0 +1,154 @@
#ifndef CAFFE_LSTM_LAYER_HPP_
#define CAFFE_LSTM_LAYER_HPP_

#include <string>
#include <utility>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/layers/recurrent_layer.hpp"
#include "caffe/net.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

template <typename Dtype> class RecurrentLayer;

/**
 * @brief Processes sequential inputs using a "Long Short-Term Memory" (LSTM)
 *        [1] style recurrent neural network (RNN). Implemented by unrolling
 *        the LSTM computation through time.
 *
 * The specific architecture used in this implementation is as described in
 * "Learning to Execute" [2], reproduced below:
 *     i_t := \sigmoid[ W_{hi} * h_{t-1} + W_{xi} * x_t + b_i ]
 *     f_t := \sigmoid[ W_{hf} * h_{t-1} + W_{xf} * x_t + b_f ]
 *     o_t := \sigmoid[ W_{ho} * h_{t-1} + W_{xo} * x_t + b_o ]
 *     g_t :=    \tanh[ W_{hg} * h_{t-1} + W_{xg} * x_t + b_g ]
 *     c_t := (f_t .* c_{t-1}) + (i_t .* g_t)
 *     h_t := o_t .* \tanh[c_t]
 * In the implementation, the i, f, o, and g computations are performed as a
 * single inner product.
 *
 * Notably, this implementation lacks the "diagonal" gates, as used in the
 * LSTM architectures described by Alex Graves [3] and others.
 *
 * [1] Hochreiter, Sepp, and Schmidhuber, Jürgen. "Long short-term memory."
 *     Neural Computation 9, no. 8 (1997): 1735-1780.
 *
 * [2] Zaremba, Wojciech, and Sutskever, Ilya. "Learning to execute."
 *     arXiv preprint arXiv:1410.4615 (2014).
 *
 * [3] Graves, Alex. "Generating sequences with recurrent neural networks."
 *     arXiv preprint arXiv:1308.0850 (2013).
 */
template <typename Dtype>
class LSTMLayer : public RecurrentLayer<Dtype> {
 public:
  explicit LSTMLayer(const LayerParameter& param)
      : RecurrentLayer<Dtype>(param) {}

  virtual inline const char* type() const { return "LSTM"; }

 protected:
  virtual void FillUnrolledNet(NetParameter* net_param) const;
  virtual void RecurrentInputBlobNames(vector<string>* names) const;
  virtual void RecurrentOutputBlobNames(vector<string>* names) const;
  virtual void RecurrentInputShapes(vector<BlobShape>* shapes) const;
  virtual void OutputBlobNames(vector<string>* names) const;
};

/**
 * @brief A helper for LSTMLayer: computes a single timestep of the
 *        non-linearity of the LSTM, producing the updated cell and hidden
 *        states.
 */
template <typename Dtype>
class LSTMUnitLayer : public Layer<Dtype> {
 public:
  explicit LSTMUnitLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "LSTMUnit"; }
  virtual inline int ExactNumBottomBlobs() const { return 3; }
  virtual inline int ExactNumTopBlobs() const { return 2; }

  virtual inline bool AllowForceBackward(const int bottom_index) const {
    // Can't propagate to sequence continuation indicators.
    return bottom_index != 2;
  }

 protected:
  /**
   * @param bottom input Blob vector (length 3)
   *   -# @f$ (1 \times N \times D) @f$
   *      the previous timestep cell state @f$ c_{t-1} @f$
   *   -# @f$ (1 \times N \times 4D) @f$
   *      the "gate inputs" @f$ [i_t', f_t', o_t', g_t'] @f$
   *   -# @f$ (1 \times N) @f$
   *      the sequence continuation indicators @f$ \delta_t @f$
   * @param top output Blob vector (length 2)
   *   -# @f$ (1 \times N \times D) @f$
   *      the updated cell state @f$ c_t @f$, computed as:
   *          i_t := \sigmoid[i_t']
   *          f_t := \sigmoid[f_t']
   *          o_t := \sigmoid[o_t']
   *          g_t := \tanh[g_t']
   *          c_t := cont_t * (f_t .* c_{t-1}) + (i_t .* g_t)
   *   -# @f$ (1 \times N \times D) @f$
   *      the updated hidden state @f$ h_t @f$, computed as:
   *          h_t := o_t .* \tanh[c_t]
   */
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  /**
   * @brief Computes the error gradient w.r.t. the LSTMUnit inputs.
   *
   * @param top output Blob vector (length 2), providing the error gradient with
   *        respect to the outputs
   *   -# @f$ (1 \times N \times D) @f$:
   *      containing error gradients @f$ \frac{\partial E}{\partial c_t} @f$
   *      with respect to the updated cell state @f$ c_t @f$
   *   -# @f$ (1 \times N \times D) @f$:
   *      containing error gradients @f$ \frac{\partial E}{\partial h_t} @f$
   *      with respect to the updated hidden state @f$ h_t @f$
   * @param propagate_down see Layer::Backward.
   * @param bottom input Blob vector (length 3), into which the error gradients
   *        with respect to the LSTMUnit inputs @f$ c_{t-1} @f$ and the gate
   *        inputs are computed. Computation of the error gradients w.r.t.
   *        the sequence indicators is not implemented.
   *   -# @f$ (1 \times N \times D) @f$
   *      the error gradient w.r.t. the previous timestep cell state
   *      @f$ c_{t-1} @f$
   *   -# @f$ (1 \times N \times 4D) @f$
   *      the error gradient w.r.t. the "gate inputs"
   *      @f$ [
   *          \frac{\partial E}{\partial i_t}
   *          \frac{\partial E}{\partial f_t}
   *          \frac{\partial E}{\partial o_t}
   *          \frac{\partial E}{\partial g_t}
   *          ] @f$
   *   -# @f$ (1 \times 1 \times N) @f$
   *      the gradient w.r.t. the sequence continuation indicators
   *      @f$ \delta_t @f$ is currently not computed.
   */
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

  /// @brief The hidden and output dimension.
  int hidden_dim_;
  Blob<Dtype> X_acts_;
};

}  // namespace caffe

#endif  // CAFFE_LSTM_LAYER_HPP_
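
To make the LSTMUnit math documented above concrete, here is a minimal standalone sketch (not part of this commit) of the per-element update. The helper name lstm_unit_step and the plain-float signature are illustrative assumptions; the actual layer operates on Blobs of shape (1 x N x D). The equations mirror the doc comment.

// Standalone sketch of one LSTMUnit timestep for a single hidden unit.
#include <cmath>

inline float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// Given the previous cell state, the four pre-activation gate inputs
// i', f', o', g', and the sequence continuation indicator cont_t,
// compute the new cell and hidden states.
void lstm_unit_step(float c_prev, float i_raw, float f_raw, float o_raw,
                    float g_raw, float cont, float* c_t, float* h_t) {
  const float i = sigmoid(i_raw);        // input gate
  const float f = sigmoid(f_raw);        // forget gate
  const float o = sigmoid(o_raw);        // output gate
  const float g = std::tanh(g_raw);      // candidate cell value
  *c_t = cont * (f * c_prev) + (i * g);  // c_t := cont_t * (f_t .* c_{t-1}) + (i_t .* g_t)
  *h_t = o * std::tanh(*c_t);            // h_t := o_t .* tanh[c_t]
}
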
@@ -0,0 +1,244 @@
#include <string>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/layers/lstm_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
void LSTMLayer<Dtype>::RecurrentInputBlobNames(vector<string>* names) const {
  names->resize(2);
  (*names)[0] = "h_0";
  (*names)[1] = "c_0";
}

template <typename Dtype>
void LSTMLayer<Dtype>::RecurrentOutputBlobNames(vector<string>* names) const {
  names->resize(2);
  (*names)[0] = "h_" + format_int(this->T_);
  (*names)[1] = "c_T";
}

template <typename Dtype>
void LSTMLayer<Dtype>::RecurrentInputShapes(vector<BlobShape>* shapes) const {
  const int num_output = this->layer_param_.recurrent_param().num_output();
  const int num_blobs = 2;
  shapes->resize(num_blobs);
  for (int i = 0; i < num_blobs; ++i) {
    (*shapes)[i].Clear();
    (*shapes)[i].add_dim(1);  // a single timestep
    (*shapes)[i].add_dim(this->N_);
    (*shapes)[i].add_dim(num_output);
  }
}

template <typename Dtype>
void LSTMLayer<Dtype>::OutputBlobNames(vector<string>* names) const {
  names->resize(1);
  (*names)[0] = "h";
}

template <typename Dtype>
void LSTMLayer<Dtype>::FillUnrolledNet(NetParameter* net_param) const {
  const int num_output = this->layer_param_.recurrent_param().num_output();
  CHECK_GT(num_output, 0) << "num_output must be positive";
  const FillerParameter& weight_filler =
      this->layer_param_.recurrent_param().weight_filler();
  const FillerParameter& bias_filler =
      this->layer_param_.recurrent_param().bias_filler();

  // Add generic LayerParameter's (without bottoms/tops) of layer types we'll
  // use to save redundant code.
  LayerParameter hidden_param;
  hidden_param.set_type("InnerProduct");
  hidden_param.mutable_inner_product_param()->set_num_output(num_output * 4);
  hidden_param.mutable_inner_product_param()->set_bias_term(false);
  hidden_param.mutable_inner_product_param()->set_axis(2);
  hidden_param.mutable_inner_product_param()->
      mutable_weight_filler()->CopyFrom(weight_filler);

  LayerParameter biased_hidden_param(hidden_param);
  biased_hidden_param.mutable_inner_product_param()->set_bias_term(true);
  biased_hidden_param.mutable_inner_product_param()->
      mutable_bias_filler()->CopyFrom(bias_filler);

  LayerParameter sum_param;
  sum_param.set_type("Eltwise");
  sum_param.mutable_eltwise_param()->set_operation(
      EltwiseParameter_EltwiseOp_SUM);

  LayerParameter scale_param;
  scale_param.set_type("Scale");
  scale_param.mutable_scale_param()->set_axis(0);

  LayerParameter slice_param;
  slice_param.set_type("Slice");
  slice_param.mutable_slice_param()->set_axis(0);

  LayerParameter split_param;
  split_param.set_type("Split");

  vector<BlobShape> input_shapes;
  RecurrentInputShapes(&input_shapes);
  CHECK_EQ(2, input_shapes.size());

  LayerParameter* input_layer_param = net_param->add_layer();
  input_layer_param->set_type("Input");
  InputParameter* input_param = input_layer_param->mutable_input_param();

  input_layer_param->add_top("c_0");
  input_param->add_shape()->CopyFrom(input_shapes[0]);

  input_layer_param->add_top("h_0");
  input_param->add_shape()->CopyFrom(input_shapes[1]);

  LayerParameter* cont_slice_param = net_param->add_layer();
  cont_slice_param->CopyFrom(slice_param);
  cont_slice_param->set_name("cont_slice");
  cont_slice_param->add_bottom("cont");
  cont_slice_param->mutable_slice_param()->set_axis(0);

  // Add layer to transform all timesteps of x to the hidden state dimension.
  //     W_xc_x = W_xc * x + b_c
  {
    LayerParameter* x_transform_param = net_param->add_layer();
    x_transform_param->CopyFrom(biased_hidden_param);
    x_transform_param->set_name("x_transform");
    x_transform_param->add_param()->set_name("W_xc");
    x_transform_param->add_param()->set_name("b_c");
    x_transform_param->add_bottom("x");
    x_transform_param->add_top("W_xc_x");
    x_transform_param->add_propagate_down(true);
  }

  if (this->static_input_) {
    // Add layer to transform x_static to the gate dimension.
    //     W_xc_x_static = W_xc_static * x_static
    LayerParameter* x_static_transform_param = net_param->add_layer();
    x_static_transform_param->CopyFrom(hidden_param);
    x_static_transform_param->mutable_inner_product_param()->set_axis(1);
    x_static_transform_param->set_name("W_xc_x_static");
    x_static_transform_param->add_param()->set_name("W_xc_static");
    x_static_transform_param->add_bottom("x_static");
    x_static_transform_param->add_top("W_xc_x_static_preshape");
    x_static_transform_param->add_propagate_down(true);

    LayerParameter* reshape_param = net_param->add_layer();
    reshape_param->set_type("Reshape");
    BlobShape* new_shape =
        reshape_param->mutable_reshape_param()->mutable_shape();
    new_shape->add_dim(1);  // One timestep.
    // Should infer this->N as the dimension so we can reshape on batch size.
    new_shape->add_dim(-1);
    new_shape->add_dim(
        x_static_transform_param->inner_product_param().num_output());
    reshape_param->set_name("W_xc_x_static_reshape");
    reshape_param->add_bottom("W_xc_x_static_preshape");
    reshape_param->add_top("W_xc_x_static");
  }

  LayerParameter* x_slice_param = net_param->add_layer();
  x_slice_param->CopyFrom(slice_param);
  x_slice_param->add_bottom("W_xc_x");
  x_slice_param->set_name("W_xc_x_slice");

  LayerParameter output_concat_layer;
  output_concat_layer.set_name("h_concat");
  output_concat_layer.set_type("Concat");
  output_concat_layer.add_top("h");
  output_concat_layer.mutable_concat_param()->set_axis(0);

  for (int t = 1; t <= this->T_; ++t) {
    string tm1s = format_int(t - 1);
    string ts = format_int(t);

    cont_slice_param->add_top("cont_" + ts);
    x_slice_param->add_top("W_xc_x_" + ts);

    // Add layers to flush the hidden state when beginning a new
    // sequence, as indicated by cont_t.
    //     h_conted_{t-1} := cont_t * h_{t-1}
    //
    // Normally, cont_t is binary (i.e., 0 or 1), so:
    //     h_conted_{t-1} := h_{t-1} if cont_t == 1
    //                       0       otherwise
    {
      LayerParameter* cont_h_param = net_param->add_layer();
      cont_h_param->CopyFrom(scale_param);
      cont_h_param->set_name("h_conted_" + tm1s);
      cont_h_param->add_bottom("h_" + tm1s);
      cont_h_param->add_bottom("cont_" + ts);
      cont_h_param->add_top("h_conted_" + tm1s);
    }

    // Add layer to compute
    //     W_hc_h_{t-1} := W_hc * h_conted_{t-1}
    {
      LayerParameter* w_param = net_param->add_layer();
      w_param->CopyFrom(hidden_param);
      w_param->set_name("transform_" + ts);
      w_param->add_param()->set_name("W_hc");
      w_param->add_bottom("h_conted_" + tm1s);
      w_param->add_top("W_hc_h_" + tm1s);
      w_param->mutable_inner_product_param()->set_axis(2);
    }

    // Add the outputs of the linear transformations to compute the gate input.
    //     gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c
    //                   = W_hc_h_{t-1} + W_xc_x_t + b_c
    {
      LayerParameter* input_sum_layer = net_param->add_layer();
      input_sum_layer->CopyFrom(sum_param);
      input_sum_layer->set_name("gate_input_" + ts);
      input_sum_layer->add_bottom("W_hc_h_" + tm1s);
      input_sum_layer->add_bottom("W_xc_x_" + ts);
      if (this->static_input_) {
        input_sum_layer->add_bottom("W_xc_x_static");
      }
      input_sum_layer->add_top("gate_input_" + ts);
    }

    // Add LSTMUnit layer to compute the cell & hidden vectors c_t and h_t.
    // Inputs: c_{t-1}, gate_input_t = (i_t, f_t, o_t, g_t), cont_t
    // Outputs: c_t, h_t
    //     [ i_t' ]
    //     [ f_t' ] := gate_input_t
    //     [ o_t' ]
    //     [ g_t' ]
    //         i_t := \sigmoid[i_t']
    //         f_t := \sigmoid[f_t']
    //         o_t := \sigmoid[o_t']
    //         g_t := \tanh[g_t']
    //         c_t := cont_t * (f_t .* c_{t-1}) + (i_t .* g_t)
    //         h_t := o_t .* \tanh[c_t]
    {
      LayerParameter* lstm_unit_param = net_param->add_layer();
      lstm_unit_param->set_type("LSTMUnit");
      lstm_unit_param->add_bottom("c_" + tm1s);
      lstm_unit_param->add_bottom("gate_input_" + ts);
      lstm_unit_param->add_bottom("cont_" + ts);
      lstm_unit_param->add_top("c_" + ts);
      lstm_unit_param->add_top("h_" + ts);
      lstm_unit_param->set_name("unit_" + ts);
    }
    output_concat_layer.add_bottom("h_" + ts);
  }  // for (int t = 1; t <= this->T_; ++t)

  {
    LayerParameter* c_T_copy_param = net_param->add_layer();
    c_T_copy_param->CopyFrom(split_param);
    c_T_copy_param->add_bottom("c_" + format_int(this->T_));
    c_T_copy_param->add_top("c_T");
  }
  net_param->add_layer()->CopyFrom(output_concat_layer);
}

INSTANTIATE_CLASS(LSTMLayer);
REGISTER_LAYER_CLASS(LSTM);

}  // namespace caffe
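
For context on how the registered "LSTM" layer type might be invoked from an outer net, below is a hedged sketch of building its LayerParameter in C++, in the same style FillUnrolledNet uses for its internal layers. The helper name MakeLstmLayerParam, the bottom blob names ("data", "cont"), and the filler settings are illustrative assumptions, not taken from this commit.

// Sketch: constructing a LayerParameter for an "LSTM" layer.
#include "caffe/proto/caffe.pb.h"

caffe::LayerParameter MakeLstmLayerParam(int num_output) {
  caffe::LayerParameter lstm_param;
  lstm_param.set_name("lstm1");
  lstm_param.set_type("LSTM");
  lstm_param.add_bottom("data");  // input sequence x, shape T x N x ...
  lstm_param.add_bottom("cont");  // sequence continuation indicators, T x N
  lstm_param.add_top("h");        // hidden states, shape T x N x num_output
  caffe::RecurrentParameter* recurrent = lstm_param.mutable_recurrent_param();
  recurrent->set_num_output(num_output);
  recurrent->mutable_weight_filler()->set_type("uniform");
  recurrent->mutable_weight_filler()->set_min(-0.08);
  recurrent->mutable_weight_filler()->set_max(0.08);
  recurrent->mutable_bias_filler()->set_type("constant");
  recurrent->mutable_bias_filler()->set_value(0);
  return lstm_param;
}
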