Add EmbedLayer for inner products with sparse input (one-hot vectors), with unit tests
jeffdonahue committed Feb 16, 2015
1 parent 81a2ad4 commit aa8dc8d
Showing 5 changed files with 439 additions and 1 deletion.
include/caffe/common_layers.hpp (38 additions, 0 deletions)
@@ -180,6 +180,44 @@ class EltwiseLayer : public Layer<Dtype> {
bool stable_prod_grad_;
};

/**
* @brief A layer for learning "embeddings" of one-hot vector input.
* Equivalent to an InnerProductLayer with one-hot vectors as input, but
* for efficiency the input is simply the "hot" index of each vector
* rather than the full one-hot encoding.
*
* TODO(dox): thorough documentation for Forward, Backward, and proto params.
*/
template <typename Dtype>
class EmbedLayer : public Layer<Dtype> {
public:
explicit EmbedLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

virtual inline const char* type() const { return "Embed"; }
virtual inline int ExactNumBottomBlobs() const { return 1; }
virtual inline int ExactNumTopBlobs() const { return 1; }

protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

int M_;
int K_;
int N_;
bool bias_term_;
Blob<Dtype> bias_multiplier_;
};

/**
* @brief Reshapes the input Blob into flat vectors.
*
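To make the equivalence described in the EmbedLayer doc comment concrete: multiplying a one-hot row vector by the K x N weight matrix selects a single row of that matrix, so the layer can skip the multiplication and copy the row at the given index. A minimal standalone C++ sketch (not part of the commit; function and array names are illustrative) comparing the two views:

#include <cassert>
#include <vector>

// embed_via_matmul: treat the input as a 1xK one-hot vector and do the
// full inner product against a KxN weight matrix stored row-major.
std::vector<float> embed_via_matmul(int hot, const std::vector<float>& W,
                                    int K, int N) {
  std::vector<float> onehot(K, 0.0f);
  onehot[hot] = 1.0f;
  std::vector<float> y(N, 0.0f);
  for (int k = 0; k < K; ++k) {
    for (int n = 0; n < N; ++n) {
      y[n] += onehot[k] * W[k * N + n];  // only k == hot contributes
    }
  }
  return y;
}

// embed_via_lookup: what EmbedLayer actually does; the weights are stored
// KxN (transposed from InnerProductLayer), so the embedding of `hot` is
// the contiguous row W[hot*N, hot*N + N).
std::vector<float> embed_via_lookup(int hot, const std::vector<float>& W,
                                    int K, int N) {
  assert(hot >= 0 && hot < K);
  return std::vector<float>(W.begin() + hot * N, W.begin() + (hot + 1) * N);
}

int main() {
  const int K = 4, N = 3, hot = 2;
  std::vector<float> W(K * N);
  for (int i = 0; i < K * N; ++i) W[i] = 0.1f * i;
  assert(embed_via_matmul(hot, W, K, N) == embed_via_lookup(hot, W, K, N));
  return 0;
}
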
src/caffe/layers/embed_layer.cpp (122 additions, 0 deletions)
@@ -0,0 +1,122 @@
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/common_layers.hpp"
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
void EmbedLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
N_ = this->layer_param_.embed_param().num_output();
CHECK_GT(N_, 0) << "EmbedLayer num_output must be positive.";
K_ = this->layer_param_.embed_param().input_dim();
CHECK_GT(K_, 0) << "EmbedLayer input_dim must be positive.";
bias_term_ = this->layer_param_.embed_param().bias_term();
// Check if we need to set up the weights
if (this->blobs_.size() > 0) {
LOG(INFO) << "Skipping parameter initialization";
} else {
if (bias_term_) {
this->blobs_.resize(2);
} else {
this->blobs_.resize(1);
}
// Initialize the weights --
// transposed from InnerProductLayer for spatial locality.
vector<int> weight_shape(2);
weight_shape[0] = K_;
weight_shape[1] = N_;
this->blobs_[0].reset(new Blob<Dtype>(weight_shape));
// fill the weights
shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
this->layer_param_.embed_param().weight_filler()));
weight_filler->Fill(this->blobs_[0].get());
// If necessary, initialize and fill the bias term
if (bias_term_) {
vector<int> bias_shape(1, N_);
this->blobs_[1].reset(new Blob<Dtype>(bias_shape));
shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(
this->layer_param_.embed_param().bias_filler()));
bias_filler->Fill(this->blobs_[1].get());
}
} // parameter initialization
this->param_propagate_down_.resize(this->blobs_.size(), true);
}

template <typename Dtype>
void EmbedLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
// Figure out the dimensions
M_ = bottom[0]->count();
vector<int> top_shape = bottom[0]->shape();
top_shape.push_back(N_);
top[0]->Reshape(top_shape);
// Set up the bias multiplier
if (bias_term_) {
vector<int> bias_shape(1, M_);
bias_multiplier_.Reshape(bias_shape);
caffe_set(M_, Dtype(1), bias_multiplier_.mutable_cpu_data());
}
}

template <typename Dtype>
void EmbedLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
const Dtype* weight = this->blobs_[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
int index;
for (int n = 0; n < M_; ++n) {
index = static_cast<int>(bottom_data[n]);
DCHECK_GE(index, 0);
DCHECK_LT(index, K_);
DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n]) << "non-integer input";
caffe_copy(N_, weight + index * N_, top_data + n * N_);
}
if (bias_term_) {
const Dtype* bias = this->blobs_[1]->cpu_data();
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, Dtype(1),
bias_multiplier_.cpu_data(), bias, Dtype(1), top_data);
}
}

template <typename Dtype>
void EmbedLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input.";
if (this->param_propagate_down_[0]) {
const Dtype* top_diff = top[0]->cpu_diff();
const Dtype* bottom_data = bottom[0]->cpu_data();
// Gradient with respect to weight
Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff();
int index;
for (int n = 0; n < M_; ++n) {
index = static_cast<int>(bottom_data[n]);
DCHECK_GE(index, 0);
DCHECK_LT(index, K_);
DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n])
<< "non-integer input";
caffe_axpy(N_, Dtype(1), top_diff + n * N_, weight_diff + index * N_);
}
}
if (bias_term_ && this->param_propagate_down_[1]) {
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff();
caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, Dtype(1), top_diff,
bias_multiplier_.cpu_data(), Dtype(1), bias_diff);
}
}

#ifdef CPU_ONLY
STUB_GPU(EmbedLayer);
#endif

INSTANTIATE_CLASS(EmbedLayer);
REGISTER_LAYER_CLASS(Embed);

} // namespace caffe
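
In the CPU implementation above, Reshape flattens the bottom blob to M_ index values and gives the top blob the bottom's shape with an extra trailing axis of size N_. Forward then gathers one weight row per index, and the weight gradient is a scatter-add of the corresponding rows of top_diff (repeated indices accumulate); the bias path is a plain GEMM/GEMV against a vector of ones, as in InnerProductLayer. A Caffe-independent C++ sketch of that core arithmetic (names are illustrative, not the layer's API):

#include <vector>

// Forward: top row m = weight row indices[m]. indices has M entries in
// [0, K); weight is K*N row-major; top is M*N.
void embed_forward(const std::vector<int>& indices,
                   const std::vector<float>& weight, int N,
                   std::vector<float>* top) {
  const int M = static_cast<int>(indices.size());
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      (*top)[m * N + n] = weight[indices[m] * N + n];  // gather one row
    }
  }
}

// Weight gradient: scatter-add each row of top_diff into the weight row it
// was read from; repeated indices accumulate, matching caffe_axpy above.
void embed_backward_weight(const std::vector<int>& indices,
                           const std::vector<float>& top_diff, int N,
                           std::vector<float>* weight_diff) {
  const int M = static_cast<int>(indices.size());
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      (*weight_diff)[indices[m] * N + n] += top_diff[m * N + n];
    }
  }
}
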
src/caffe/layers/embed_layer.cu (80 additions, 0 deletions)
@@ -0,0 +1,80 @@
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/common_layers.hpp"
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
__global__ void EmbedForward(const int nthreads, const Dtype* bottom_data,
const Dtype* weight, const int M, const int N, const int K,
Dtype* top_data) {
CUDA_KERNEL_LOOP(top_index, nthreads) {
const int n = top_index / N;
const int d = top_index % N;
const int index = static_cast<int>(bottom_data[n]);
const int weight_index = index * N + d;
top_data[top_index] = weight[weight_index];
}
}

template <typename Dtype>
__global__ void EmbedBackward(const int nthreads, const Dtype* bottom_data,
const Dtype* top_diff, const int M, const int N, const int K,
Dtype* weight_diff) {
CUDA_KERNEL_LOOP(weight_index, nthreads) {
const int index = weight_index / N;
const int output_index = weight_index % N;
for (int n = 0; n < M; ++n) {
if (static_cast<int>(bottom_data[n]) == index) {
weight_diff[weight_index] += top_diff[n * N + output_index];
}
}
}
}

template <typename Dtype>
void EmbedLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
const Dtype* weight = this->blobs_[0]->gpu_data();
const int count = top[0]->count();
EmbedForward<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
<<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, weight, M_, N_, K_, top_data);
if (bias_term_) {
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, Dtype(1),
bias_multiplier_.gpu_data(),
this->blobs_[1]->gpu_data(), Dtype(1), top_data);
}
}

template <typename Dtype>
void EmbedLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input.";
if (this->param_propagate_down_[0]) {
const int count = this->blobs_[0]->count();
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
EmbedBackward<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
<<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, top_diff, M_, N_, K_, weight_diff);
}
if (bias_term_ && this->param_propagate_down_[1]) {
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff();
caffe_gpu_gemv<Dtype>(CblasTrans, M_, N_, Dtype(1), top_diff,
bias_multiplier_.gpu_data(), Dtype(1), bias_diff);
}
}

INSTANTIATE_LAYER_GPU_FUNCS(EmbedLayer);

} // namespace caffe
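
The GPU backward kernel above uses a different parallelization than the CPU loop: it launches one thread per weight-gradient element (count = K_ * N_), and each thread scans all M_ input indices, accumulating only the rows that match. Since no two threads write the same weight_diff element, duplicate indices within a batch need no atomic operations, at the cost of every thread reading the whole index array. A serial C++ sketch of what one such thread computes (illustrative names, not the kernel itself):

// One logical EmbedBackward thread, written serially. weight_index selects
// one element of the KxN weight gradient; the caller adds the result into
// weight_diff[weight_index].
float embed_backward_one_element(int weight_index, const float* bottom_data,
                                 const float* top_diff, int M, int N) {
  const int row = weight_index / N;  // which embedding row (input index)
  const int col = weight_index % N;  // which output dimension
  float grad = 0.0f;
  for (int m = 0; m < M; ++m) {
    if (static_cast<int>(bottom_data[m]) == row) {
      grad += top_diff[m * N + col];  // every input that selected this row
    }
  }
  return grad;
}
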
src/caffe/proto/caffe.proto (16 additions, 1 deletion)
@@ -260,7 +260,7 @@ message ParamSpec {
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
// LayerParameter next available layer-specific ID: 129 (last added: window_data_param)
// LayerParameter next available layer-specific ID: 131 (last added: embed_param)
message LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the layer type
@@ -309,6 +309,7 @@ message LayerParameter {
optional DropoutParameter dropout_param = 108;
optional DummyDataParameter dummy_data_param = 109;
optional EltwiseParameter eltwise_param = 110;
optional EmbedParameter embed_param = 130;
optional ExpParameter exp_param = 111;
optional HDF5DataParameter hdf5_data_param = 112;
optional HDF5OutputParameter hdf5_output_param = 113;
@@ -484,6 +485,20 @@ message EltwiseParameter {
optional bool stable_prod_grad = 3 [default = true];
}

// Message that stores parameters used by EmbedLayer
message EmbedParameter {
optional uint32 num_output = 1; // The number of outputs for the layer
// The input is given as integers to be interpreted as one-hot
// vector indices with dimension input_dim. Hence input_dim should be
// 1 greater than the maximum possible input value.
optional uint32 input_dim = 2;

optional bool bias_term = 3 [default = true]; // Whether to use a bias term
optional FillerParameter weight_filler = 4; // The filler for the weight
optional FillerParameter bias_filler = 5; // The filler for the bias

}

// Message that stores parameters used by ExpLayer
message ExpParameter {
// ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0.
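
The new EmbedParameter fields can be set either in a prototxt layer definition or directly through the protobuf-generated C++ setters. Below is a hypothetical end-to-end sketch (not from the commit; the sizes, values, and variable names are made up) that configures an EmbedLayer with a 100-entry vocabulary and 10-dimensional embeddings, then runs a CPU forward pass:

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common_layers.hpp"
#include "caffe/proto/caffe.pb.h"

int main() {
  caffe::LayerParameter layer_param;
  caffe::EmbedParameter* embed_param = layer_param.mutable_embed_param();
  embed_param->set_num_output(10);   // N_: embedding dimension
  embed_param->set_input_dim(100);   // K_: 1 + maximum possible input value
  embed_param->set_bias_term(false);
  embed_param->mutable_weight_filler()->set_type("uniform");
  embed_param->mutable_weight_filler()->set_min(-1);
  embed_param->mutable_weight_filler()->set_max(1);

  // The bottom blob holds raw indices, one per sample; the top blob gets
  // the bottom shape with a trailing axis of size num_output.
  caffe::Blob<float> bottom(4, 1, 1, 1);
  float* indices = bottom.mutable_cpu_data();
  indices[0] = 0; indices[1] = 7; indices[2] = 7; indices[3] = 99;

  caffe::Blob<float> top;
  std::vector<caffe::Blob<float>*> bottom_vec(1, &bottom), top_vec(1, &top);

  caffe::EmbedLayer<float> layer(layer_param);
  layer.SetUp(bottom_vec, top_vec);
  layer.Forward(bottom_vec, top_vec);
  // Each of the 4 rows of top.cpu_data() is now the 10-dimensional weight
  // row selected by the corresponding input index.
  return 0;
}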