-
Notifications
You must be signed in to change notification settings - Fork 18.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2032 from jeffdonahue/embed-layer
Embed layer for lookup table of one hot encodings
- Loading branch information
Showing
7 changed files
with
488 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#ifndef CAFFE_UTIL_GPU_UTIL_H_ | ||
#define CAFFE_UTIL_GPU_UTIL_H_ | ||
|
||
namespace caffe { | ||
|
||
template <typename Dtype> | ||
inline __device__ Dtype caffe_gpu_atomic_add(const Dtype val, Dtype* address); | ||
|
||
template <> | ||
inline __device__ | ||
float caffe_gpu_atomic_add(const float val, float* address) { | ||
return atomicAdd(address, val); | ||
} | ||
|
||
// double atomicAdd implementation taken from: | ||
// http://docs.nvidia.com/cuda/cuda-c-programming-guide/#axzz3PVCpVsEG | ||
template <> | ||
inline __device__ | ||
double caffe_gpu_atomic_add(const double val, double* address) { | ||
unsigned long long int* address_as_ull = // NOLINT(runtime/int) | ||
// NOLINT_NEXT_LINE(runtime/int) | ||
reinterpret_cast<unsigned long long int*>(address); | ||
unsigned long long int old = *address_as_ull; // NOLINT(runtime/int) | ||
unsigned long long int assumed; // NOLINT(runtime/int) | ||
do { | ||
assumed = old; | ||
old = atomicCAS(address_as_ull, assumed, | ||
__double_as_longlong(val + __longlong_as_double(assumed))); | ||
} while (assumed != old); | ||
return __longlong_as_double(old); | ||
} | ||
|
||
} // namespace caffe | ||
|
||
#endif // CAFFE_UTIL_GPU_UTIL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
#include <vector> | ||
|
||
#include "caffe/blob.hpp" | ||
#include "caffe/common.hpp" | ||
#include "caffe/common_layers.hpp" | ||
#include "caffe/filler.hpp" | ||
#include "caffe/layer.hpp" | ||
#include "caffe/util/math_functions.hpp" | ||
|
||
namespace caffe { | ||
|
||
template <typename Dtype> | ||
void EmbedLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, | ||
const vector<Blob<Dtype>*>& top) { | ||
N_ = this->layer_param_.embed_param().num_output(); | ||
CHECK_GT(N_, 0) << "EmbedLayer num_output must be positive."; | ||
K_ = this->layer_param_.embed_param().input_dim(); | ||
CHECK_GT(K_, 0) << "EmbedLayer input_dim must be positive."; | ||
bias_term_ = this->layer_param_.embed_param().bias_term(); | ||
// Check if we need to set up the weights | ||
if (this->blobs_.size() > 0) { | ||
LOG(INFO) << "Skipping parameter initialization"; | ||
} else { | ||
if (bias_term_) { | ||
this->blobs_.resize(2); | ||
} else { | ||
this->blobs_.resize(1); | ||
} | ||
// Initialize the weights -- | ||
// transposed from InnerProductLayer for spatial locality. | ||
vector<int> weight_shape(2); | ||
weight_shape[0] = K_; | ||
weight_shape[1] = N_; | ||
this->blobs_[0].reset(new Blob<Dtype>(weight_shape)); | ||
// fill the weights | ||
shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>( | ||
this->layer_param_.embed_param().weight_filler())); | ||
weight_filler->Fill(this->blobs_[0].get()); | ||
// If necessary, initialize and fill the bias term | ||
if (bias_term_) { | ||
vector<int> bias_shape(1, N_); | ||
this->blobs_[1].reset(new Blob<Dtype>(bias_shape)); | ||
shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>( | ||
this->layer_param_.embed_param().bias_filler())); | ||
bias_filler->Fill(this->blobs_[1].get()); | ||
} | ||
} // parameter initialization | ||
this->param_propagate_down_.resize(this->blobs_.size(), true); | ||
} | ||
|
||
template <typename Dtype> | ||
void EmbedLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom, | ||
const vector<Blob<Dtype>*>& top) { | ||
// Figure out the dimensions | ||
M_ = bottom[0]->count(); | ||
vector<int> top_shape = bottom[0]->shape(); | ||
top_shape.push_back(N_); | ||
top[0]->Reshape(top_shape); | ||
// Set up the bias multiplier | ||
if (bias_term_) { | ||
vector<int> bias_shape(1, M_); | ||
bias_multiplier_.Reshape(bias_shape); | ||
caffe_set(M_, Dtype(1), bias_multiplier_.mutable_cpu_data()); | ||
} | ||
} | ||
|
||
template <typename Dtype> | ||
void EmbedLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, | ||
const vector<Blob<Dtype>*>& top) { | ||
const Dtype* bottom_data = bottom[0]->cpu_data(); | ||
const Dtype* weight = this->blobs_[0]->cpu_data(); | ||
Dtype* top_data = top[0]->mutable_cpu_data(); | ||
int index; | ||
for (int n = 0; n < M_; ++n) { | ||
index = static_cast<int>(bottom_data[n]); | ||
DCHECK_GE(index, 0); | ||
DCHECK_LT(index, K_); | ||
DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n]) << "non-integer input"; | ||
caffe_copy(N_, weight + index * N_, top_data + n * N_); | ||
} | ||
if (bias_term_) { | ||
const Dtype* bias = this->blobs_[1]->cpu_data(); | ||
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, Dtype(1), | ||
bias_multiplier_.cpu_data(), bias, Dtype(1), top_data); | ||
} | ||
} | ||
|
||
template <typename Dtype> | ||
void EmbedLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top, | ||
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) { | ||
CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input."; | ||
if (this->param_propagate_down_[0]) { | ||
const Dtype* top_diff = top[0]->cpu_diff(); | ||
const Dtype* bottom_data = bottom[0]->cpu_data(); | ||
// Gradient with respect to weight | ||
Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff(); | ||
int index; | ||
for (int n = 0; n < M_; ++n) { | ||
index = static_cast<int>(bottom_data[n]); | ||
DCHECK_GE(index, 0); | ||
DCHECK_LT(index, K_); | ||
DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n]) | ||
<< "non-integer input"; | ||
caffe_axpy(N_, Dtype(1), top_diff + n * N_, weight_diff + index * N_); | ||
} | ||
} | ||
if (bias_term_ && this->param_propagate_down_[1]) { | ||
const Dtype* top_diff = top[0]->cpu_diff(); | ||
Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff(); | ||
caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, Dtype(1), top_diff, | ||
bias_multiplier_.cpu_data(), Dtype(1), bias_diff); | ||
} | ||
} | ||
|
||
#ifdef CPU_ONLY | ||
STUB_GPU(EmbedLayer); | ||
#endif | ||
|
||
INSTANTIATE_CLASS(EmbedLayer); | ||
REGISTER_LAYER_CLASS(Embed); | ||
|
||
} // namespace caffe |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
#include <vector> | ||
|
||
#include "caffe/blob.hpp" | ||
#include "caffe/common.hpp" | ||
#include "caffe/common_layers.hpp" | ||
#include "caffe/filler.hpp" | ||
#include "caffe/layer.hpp" | ||
#include "caffe/util/gpu_util.cuh" | ||
#include "caffe/util/math_functions.hpp" | ||
|
||
namespace caffe { | ||
|
||
template <typename Dtype> | ||
__global__ void EmbedForward(const int nthreads, const Dtype* bottom_data, | ||
const Dtype* weight, const int M, const int N, const int K, | ||
Dtype* top_data) { | ||
CUDA_KERNEL_LOOP(top_index, nthreads) { | ||
const int n = top_index / N; | ||
const int d = top_index % N; | ||
const int index = static_cast<int>(bottom_data[n]); | ||
const int weight_index = index * N + d; | ||
top_data[top_index] = weight[weight_index]; | ||
} | ||
} | ||
|
||
template <typename Dtype> | ||
__global__ void EmbedBackward(const int nthreads, const Dtype* bottom_data, | ||
const Dtype* top_diff, const int M, const int N, const int K, | ||
Dtype* weight_diff); | ||
|
||
template <typename Dtype> | ||
__global__ void EmbedBackward(const int nthreads, const Dtype* bottom_data, | ||
const Dtype* top_diff, const int M, const int N, const int K, | ||
Dtype* weight_diff) { | ||
CUDA_KERNEL_LOOP(top_index, nthreads) { | ||
const int n = top_index / N; | ||
const int d = top_index % N; | ||
const int index = static_cast<int>(bottom_data[n]); | ||
const int weight_index = index * N + d; | ||
caffe_gpu_atomic_add(top_diff[top_index], weight_diff + weight_index); | ||
} | ||
} | ||
|
||
template <typename Dtype> | ||
void EmbedLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom, | ||
const vector<Blob<Dtype>*>& top) { | ||
const Dtype* bottom_data = bottom[0]->gpu_data(); | ||
Dtype* top_data = top[0]->mutable_gpu_data(); | ||
const Dtype* weight = this->blobs_[0]->gpu_data(); | ||
const int count = top[0]->count(); | ||
EmbedForward<Dtype> // NOLINT_NEXT_LINE(whitespace/operators) | ||
<<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>( | ||
count, bottom_data, weight, M_, N_, K_, top_data); | ||
if (bias_term_) { | ||
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, Dtype(1), | ||
bias_multiplier_.gpu_data(), | ||
this->blobs_[1]->gpu_data(), Dtype(1), top_data); | ||
} | ||
} | ||
|
||
template <typename Dtype> | ||
void EmbedLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top, | ||
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) { | ||
CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input."; | ||
if (this->param_propagate_down_[0]) { | ||
const int top_count = top[0]->count(); | ||
const int count = this->blobs_[0]->count(); | ||
const Dtype* top_diff = top[0]->gpu_diff(); | ||
const Dtype* bottom_data = bottom[0]->gpu_data(); | ||
Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff(); | ||
EmbedBackward<Dtype> // NOLINT_NEXT_LINE(whitespace/operators) | ||
<<<CAFFE_GET_BLOCKS(top_count), CAFFE_CUDA_NUM_THREADS>>>( | ||
top_count, bottom_data, top_diff, M_, N_, K_, weight_diff); | ||
} | ||
if (bias_term_ && this->param_propagate_down_[1]) { | ||
const Dtype* top_diff = top[0]->gpu_diff(); | ||
Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff(); | ||
caffe_gpu_gemv<Dtype>(CblasTrans, M_, N_, Dtype(1), top_diff, | ||
bias_multiplier_.gpu_data(), Dtype(1), bias_diff); | ||
} | ||
} | ||
|
||
INSTANTIATE_LAYER_GPU_FUNCS(EmbedLayer); | ||
|
||
} // namespace caffe |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.