Embed layer #2032

Merged 4 commits on Aug 25, 2015
38 changes: 38 additions & 0 deletions include/caffe/common_layers.hpp
@@ -180,6 +180,44 @@ class EltwiseLayer : public Layer<Dtype> {
  bool stable_prod_grad_;
};

/**
 * @brief A layer for learning "embeddings" of one-hot vector input.
 *        Equivalent to an InnerProductLayer with one-hot vectors as input, but
 *        for efficiency the input is the "hot" index of each column itself.
 *
 * TODO(dox): thorough documentation for Forward, Backward, and proto params.
 */
template <typename Dtype>
class EmbedLayer : public Layer<Dtype> {
 public:
  explicit EmbedLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "Embed"; }
  virtual inline int ExactNumBottomBlobs() const { return 1; }
  virtual inline int ExactNumTopBlobs() const { return 1; }

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

  int M_;
  int K_;
  int N_;
  bool bias_term_;
  Blob<Dtype> bias_multiplier_;
};

/**
 * @brief Takes two+ Blobs, interprets last Blob as a selector and
 *  filter remaining Blobs accordingly with selector data (0 means that
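An aside on the EmbedLayer class comment above: looking up row `index` of the K x N weight matrix gives exactly the product of a one-hot row vector with that matrix; the layer only skips the multiplications by zero. A minimal standalone sketch of the equivalence (plain C++, toy sizes, independent of the Caffe classes in this diff):

#include <cassert>
#include <vector>

int main() {
  const int K = 3, N = 2;  // input_dim, num_output
  const std::vector<double> W = {0.1, 0.2,   // row 0
                                 0.3, 0.4,   // row 1
                                 0.5, 0.6};  // row 2
  const int index = 1;  // the "hot" index given as input

  // Embed-style lookup: copy row `index` of W.
  std::vector<double> lookup(W.begin() + index * N,
                             W.begin() + (index + 1) * N);

  // InnerProduct-style: one-hot row vector times W.
  std::vector<double> one_hot(K, 0.0);
  one_hot[index] = 1.0;
  std::vector<double> product(N, 0.0);
  for (int k = 0; k < K; ++k)
    for (int d = 0; d < N; ++d)
      product[d] += one_hot[k] * W[k * N + d];

  assert(lookup == product);  // same result; the lookup skips the O(K*N) work
  return 0;
}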
11 changes: 8 additions & 3 deletions include/caffe/test/test_gradient_check_util.hpp
@@ -45,6 +45,10 @@ class GradientChecker {
  void CheckGradientEltwise(Layer<Dtype>* layer,
      const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);

  // Checks the gradient of a single output with respect to particular input
  // blob(s). If check_bottom = i >= 0, check only the ith bottom Blob.
  // If check_bottom == -1, check everything -- all bottom Blobs and all
  // param Blobs. Otherwise (if check_bottom < -1), check only param Blobs.
  void CheckGradientSingle(Layer<Dtype>* layer,
      const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top,
      int check_bottom, int top_id, int top_data_id, bool element_wise = false);
@@ -83,21 +87,22 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>* layer,
  // First, figure out what blobs we need to check against, and zero init
  // parameter blobs.
  vector<Blob<Dtype>*> blobs_to_check;
-  vector<bool> propagate_down(bottom.size(), check_bottom < 0);
+  vector<bool> propagate_down(bottom.size(), check_bottom == -1);
  for (int i = 0; i < layer->blobs().size(); ++i) {
    Blob<Dtype>* blob = layer->blobs()[i].get();
    caffe_set(blob->count(), static_cast<Dtype>(0), blob->mutable_cpu_diff());
    blobs_to_check.push_back(blob);
  }
-  if (check_bottom < 0) {
+  if (check_bottom == -1) {
    for (int i = 0; i < bottom.size(); ++i) {
      blobs_to_check.push_back(bottom[i]);
    }
-  } else {
+  } else if (check_bottom >= 0) {
    CHECK_LT(check_bottom, bottom.size());
    blobs_to_check.push_back(bottom[check_bottom]);
    propagate_down[check_bottom] = true;
  }
  CHECK_GT(blobs_to_check.size(), 0) << "No blobs to check.";
  // Compute the gradient analytically using Backward
  Caffe::set_random_seed(seed_);
  // Ignore the loss from the layer (it's just the weighted sum of the losses
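The check_bottom < -1 mode added above exists for layers like this PR's EmbedLayer, whose bottom blob holds integer indices and therefore has no defined gradient. A hedged sketch of how a test might use it (typical Caffe test pattern; the blob vector names are assumptions, not taken from this diff):

// Sketch of a gradient-test body for EmbedLayer (names hypothetical):
typedef float Dtype;
LayerParameter layer_param;
layer_param.mutable_embed_param()->set_num_output(10);
layer_param.mutable_embed_param()->set_input_dim(5);
EmbedLayer<Dtype> layer(layer_param);
GradientChecker<Dtype> checker(1e-2, 1e-3);
// check_bottom = -2 (< -1): check only the parameter blobs, since the
// integer-index bottom blob cannot receive a gradient (see Backward_cpu).
checker.CheckGradientExhaustive(&layer, bottom_vec, top_vec, -2);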
35 changes: 35 additions & 0 deletions include/caffe/util/gpu_util.cuh
@@ -0,0 +1,35 @@
#ifndef CAFFE_UTIL_GPU_UTIL_H_
#define CAFFE_UTIL_GPU_UTIL_H_

namespace caffe {

template <typename Dtype>
inline __device__ Dtype caffe_gpu_atomic_add(const Dtype val, Dtype* address);

template <>
inline __device__
float caffe_gpu_atomic_add(const float val, float* address) {
  return atomicAdd(address, val);
}

// double atomicAdd implementation taken from:
// http://docs.nvidia.com/cuda/cuda-c-programming-guide/#axzz3PVCpVsEG
template <>
inline __device__
double caffe_gpu_atomic_add(const double val, double* address) {
  unsigned long long int* address_as_ull =  // NOLINT(runtime/int)
      // NOLINT_NEXT_LINE(runtime/int)
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *address_as_ull;  // NOLINT(runtime/int)
  unsigned long long int assumed;  // NOLINT(runtime/int)
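  // Spin until the compare-and-swap succeeds: if another thread updates the
  // address between our read and our CAS, retry with the freshly read value.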
  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
        __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
  return __longlong_as_double(old);
}

} // namespace caffe

#endif // CAFFE_UTIL_GPU_UTIL_H_
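CUDA provides a native float atomicAdd, but no double overload on the hardware of this era, hence the compare-and-swap emulation above. A toy kernel (hypothetical, not part of this PR; assumes Caffe's CUDA_KERNEL_LOOP macro) showing the intended usage pattern:

// Hypothetical example: accumulate values into bins. Many threads may
// target the same bin, so a plain += would be a data race.
template <typename Dtype>
__global__ void BinSumKernel(const int n, const int* bins,
    const Dtype* vals, Dtype* sums) {
  CUDA_KERNEL_LOOP(i, n) {
    caffe_gpu_atomic_add(vals[i], sums + bins[i]);
  }
}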
122 changes: 122 additions & 0 deletions src/caffe/layers/embed_layer.cpp
@@ -0,0 +1,122 @@
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/common_layers.hpp"
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
void EmbedLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  N_ = this->layer_param_.embed_param().num_output();
  CHECK_GT(N_, 0) << "EmbedLayer num_output must be positive.";
  K_ = this->layer_param_.embed_param().input_dim();
  CHECK_GT(K_, 0) << "EmbedLayer input_dim must be positive.";
  bias_term_ = this->layer_param_.embed_param().bias_term();
  // Check if we need to set up the weights
  if (this->blobs_.size() > 0) {
    LOG(INFO) << "Skipping parameter initialization";
  } else {
    if (bias_term_) {
      this->blobs_.resize(2);
    } else {
      this->blobs_.resize(1);
    }
    // Initialize the weights --
    // transposed from InnerProductLayer for spatial locality.
    vector<int> weight_shape(2);
    weight_shape[0] = K_;
    weight_shape[1] = N_;
    this->blobs_[0].reset(new Blob<Dtype>(weight_shape));
    // fill the weights
    shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
        this->layer_param_.embed_param().weight_filler()));
    weight_filler->Fill(this->blobs_[0].get());
    // If necessary, initialize and fill the bias term
    if (bias_term_) {
      vector<int> bias_shape(1, N_);
      this->blobs_[1].reset(new Blob<Dtype>(bias_shape));
      shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(
          this->layer_param_.embed_param().bias_filler()));
      bias_filler->Fill(this->blobs_[1].get());
    }
  }  // parameter initialization
  this->param_propagate_down_.resize(this->blobs_.size(), true);
}

template <typename Dtype>
void EmbedLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  // Figure out the dimensions
  M_ = bottom[0]->count();
  vector<int> top_shape = bottom[0]->shape();
  top_shape.push_back(N_);
  top[0]->Reshape(top_shape);
  // Set up the bias multiplier
  if (bias_term_) {
    vector<int> bias_shape(1, M_);
    bias_multiplier_.Reshape(bias_shape);
    caffe_set(M_, Dtype(1), bias_multiplier_.mutable_cpu_data());
  }
}
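Concretely: a bottom blob of shape (10, 20) holding indices yields M_ = 200, and with num_output set to 32 the top is reshaped to (10, 20, 32) -- the embedding axis is simply appended to the input shape.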

template <typename Dtype>
void EmbedLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  const Dtype* bottom_data = bottom[0]->cpu_data();
  const Dtype* weight = this->blobs_[0]->cpu_data();
  Dtype* top_data = top[0]->mutable_cpu_data();
  int index;
  for (int n = 0; n < M_; ++n) {
    index = static_cast<int>(bottom_data[n]);
    DCHECK_GE(index, 0);
    DCHECK_LT(index, K_);
    DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n]) << "non-integer input";
    caffe_copy(N_, weight + index * N_, top_data + n * N_);
  }
  if (bias_term_) {
    const Dtype* bias = this->blobs_[1]->cpu_data();
    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, Dtype(1),
        bias_multiplier_.cpu_data(), bias, Dtype(1), top_data);
  }
}

template <typename Dtype>
void EmbedLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input.";
  if (this->param_propagate_down_[0]) {
    const Dtype* top_diff = top[0]->cpu_diff();
    const Dtype* bottom_data = bottom[0]->cpu_data();
    // Gradient with respect to weight
    Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff();
    int index;
    for (int n = 0; n < M_; ++n) {
      index = static_cast<int>(bottom_data[n]);
      DCHECK_GE(index, 0);
      DCHECK_LT(index, K_);
      DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n])
          << "non-integer input";
      caffe_axpy(N_, Dtype(1), top_diff + n * N_, weight_diff + index * N_);
    }
  }
  if (bias_term_ && this->param_propagate_down_[1]) {
    const Dtype* top_diff = top[0]->cpu_diff();
    Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff();
    caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, Dtype(1), top_diff,
        bias_multiplier_.cpu_data(), Dtype(1), bias_diff);
  }
}

#ifdef CPU_ONLY
STUB_GPU(EmbedLayer);
#endif

INSTANTIATE_CLASS(EmbedLayer);
REGISTER_LAYER_CLASS(Embed);

} // namespace caffe
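One subtlety worth calling out: the weight gradient above is a scatter-add, so an index that appears several times in a batch accumulates several contributions into the same weight row -- this is also why the GPU path below needs atomics. A minimal standalone illustration of the accumulation (plain C++, toy sizes):

#include <cstdio>
#include <vector>

int main() {
  const int K = 4, N = 2;                      // input_dim, num_output
  const std::vector<int> indices = {1, 3, 1};  // M_ = 3 inputs; index 1 repeats
  const std::vector<double> top_diff(indices.size() * N, 1.0);  // all-ones diff
  std::vector<double> weight_diff(K * N, 0.0);
  for (size_t n = 0; n < indices.size(); ++n)
    for (int d = 0; d < N; ++d)  // same accumulation as the caffe_axpy above
      weight_diff[indices[n] * N + d] += top_diff[n * N + d];
  for (int k = 0; k < K; ++k)  // row 1 gets 2.0, row 3 gets 1.0, others 0.0
    std::printf("row %d: %g %g\n", k, weight_diff[k * N], weight_diff[k * N + 1]);
  return 0;
}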
85 changes: 85 additions & 0 deletions src/caffe/layers/embed_layer.cu
@@ -0,0 +1,85 @@
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/common_layers.hpp"
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/gpu_util.cuh"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
__global__ void EmbedForward(const int nthreads, const Dtype* bottom_data,
    const Dtype* weight, const int M, const int N, const int K,
    Dtype* top_data) {
  CUDA_KERNEL_LOOP(top_index, nthreads) {
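    // One thread per output element: n is the input position (which row of
    // the weight matrix to read), d is the position within the embedding.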
    const int n = top_index / N;
    const int d = top_index % N;
    const int index = static_cast<int>(bottom_data[n]);
    const int weight_index = index * N + d;
    top_data[top_index] = weight[weight_index];
  }
}

template <typename Dtype>
__global__ void EmbedBackward(const int nthreads, const Dtype* bottom_data,
    const Dtype* top_diff, const int M, const int N, const int K,
    Dtype* weight_diff);

template <typename Dtype>
__global__ void EmbedBackward(const int nthreads, const Dtype* bottom_data,
    const Dtype* top_diff, const int M, const int N, const int K,
    Dtype* weight_diff) {
  CUDA_KERNEL_LOOP(top_index, nthreads) {
    const int n = top_index / N;
    const int d = top_index % N;
    const int index = static_cast<int>(bottom_data[n]);
    const int weight_index = index * N + d;
    caffe_gpu_atomic_add(top_diff[top_index], weight_diff + weight_index);
  }
}

template <typename Dtype>
void EmbedLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  const Dtype* bottom_data = bottom[0]->gpu_data();
  Dtype* top_data = top[0]->mutable_gpu_data();
  const Dtype* weight = this->blobs_[0]->gpu_data();
  const int count = top[0]->count();
  EmbedForward<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
      <<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
      count, bottom_data, weight, M_, N_, K_, top_data);
  if (bias_term_) {
    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, Dtype(1),
        bias_multiplier_.gpu_data(),
        this->blobs_[1]->gpu_data(), Dtype(1), top_data);
  }
}

template <typename Dtype>
void EmbedLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input.";
  if (this->param_propagate_down_[0]) {
    const int top_count = top[0]->count();
    const int count = this->blobs_[0]->count();
    const Dtype* top_diff = top[0]->gpu_diff();
    const Dtype* bottom_data = bottom[0]->gpu_data();
    Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
    EmbedBackward<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
        <<<CAFFE_GET_BLOCKS(top_count), CAFFE_CUDA_NUM_THREADS>>>(
        top_count, bottom_data, top_diff, M_, N_, K_, weight_diff);
  }
  if (bias_term_ && this->param_propagate_down_[1]) {
    const Dtype* top_diff = top[0]->gpu_diff();
    Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff();
    caffe_gpu_gemv<Dtype>(CblasTrans, M_, N_, Dtype(1), top_diff,
        bias_multiplier_.gpu_data(), Dtype(1), bias_diff);
  }
}

INSTANTIATE_LAYER_GPU_FUNCS(EmbedLayer);

} // namespace caffe
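Two notes for readers of the kernels above. CUDA_KERNEL_LOOP is Caffe's grid-stride loop macro, quoted below from device_alternate.hpp for reference (check your checkout; this may drift between versions). And because EmbedBackward accumulates through atomics, the order of the floating-point additions -- and hence the low-order bits of the weight gradient -- can vary from run to run.

#define CUDA_KERNEL_LOOP(i, n) \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
       i < (n); \
       i += blockDim.x * gridDim.x)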
18 changes: 17 additions & 1 deletion src/caffe/proto/caffe.proto
@@ -276,7 +276,7 @@ message ParamSpec {
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
-// LayerParameter next available layer-specific ID: 137 (last added: reduction_param)
+// LayerParameter next available layer-specific ID: 138 (last added: embed_param)
message LayerParameter {
  optional string name = 1; // the layer name
  optional string type = 2; // the layer type
@@ -332,6 +332,7 @@ message LayerParameter {
  optional DropoutParameter dropout_param = 108;
  optional DummyDataParameter dummy_data_param = 109;
  optional EltwiseParameter eltwise_param = 110;
  optional EmbedParameter embed_param = 137;
  optional ExpParameter exp_param = 111;
  optional FlattenParameter flatten_param = 135;
  optional HDF5DataParameter hdf5_data_param = 112;
@@ -533,6 +534,21 @@ message EltwiseParameter {
  optional bool stable_prod_grad = 3 [default = true];
}

// Message that stores parameters used by EmbedLayer
message EmbedParameter {
  optional uint32 num_output = 1;  // The number of outputs for the layer
  // The input is given as integers to be interpreted as one-hot
  // vector indices with dimension input_dim. Hence input_dim should be
  // 1 greater than the maximum possible input value.
  optional uint32 input_dim = 2;

  optional bool bias_term = 3 [default = true];  // Whether to use a bias term
  optional FillerParameter weight_filler = 4;  // The filler for the weight
  optional FillerParameter bias_filler = 5;  // The filler for the bias
}

// Message that stores parameters used by ExpLayer
message ExpParameter {
  // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0.
  // Or if base is set to the default (-1), base is set to e,
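Finally, a hedged sketch of what a net definition using the new layer could look like (layer and blob names plus all values are hypothetical; only the field names come from the EmbedParameter message above):

layer {
  name: "embedding"
  type: "Embed"
  bottom: "word_indices"   # integers in [0, input_dim)
  top: "word_vectors"
  embed_param {
    num_output: 128    # embedding dimension
    input_dim: 10000   # vocabulary size = max index + 1
    bias_term: false
    weight_filler { type: "uniform" min: -0.08 max: 0.08 }
  }
}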