Sparse Data Support #937

Closed
wants to merge 1 commit
2 changes: 1 addition & 1 deletion Makefile
@@ -165,7 +165,7 @@ INCLUDE_DIRS += $(BUILD_INCLUDE_DIR) ./src ./include
ifneq ($(CPU_ONLY), 1)
INCLUDE_DIRS += $(CUDA_INCLUDE_DIR)
LIBRARY_DIRS += $(CUDA_LIB_DIR)
LIBRARIES := cudart cublas curand
LIBRARIES := cudart cublas cusparse curand
endif
LIBRARIES += glog gflags protobuf leveldb snappy \
lmdb boost_system hdf5_hl hdf5 m \
35 changes: 18 additions & 17 deletions include/caffe/blob.hpp
@@ -37,7 +37,7 @@ class Blob {
* an error; either Net::Forward or Net::Reshape need to be called to
* propagate the new input shape to higher layers.
*/
void Reshape(const int num, const int channels, const int height,
virtual void Reshape(const int num, const int channels, const int height,
const int width);
void ReshapeLike(const Blob& other);
inline int num() const { return num_; }
@@ -69,38 +69,39 @@
void CopyFrom(const Blob<Dtype>& source, bool copy_diff = false,
bool reshape = false);

inline Dtype data_at(const int n, const int c, const int h,
virtual inline Dtype data_at(const int n, const int c, const int h,
const int w) const {
return *(cpu_data() + offset(n, c, h, w));
}

inline Dtype diff_at(const int n, const int c, const int h,
virtual inline Dtype diff_at(const int n, const int c, const int h,
const int w) const {
return *(cpu_diff() + offset(n, c, h, w));
}

inline const shared_ptr<SyncedMemory>& data() const {
virtual inline const shared_ptr<SyncedMemory>& data() const {
CHECK(data_);
return data_;
}

inline const shared_ptr<SyncedMemory>& diff() const {
virtual inline const shared_ptr<SyncedMemory>& diff() const {
CHECK(diff_);
return diff_;
}

const Dtype* cpu_data() const;
void set_cpu_data(Dtype* data);
const Dtype* gpu_data() const;
const Dtype* cpu_diff() const;
const Dtype* gpu_diff() const;
Dtype* mutable_cpu_data();
Dtype* mutable_gpu_data();
Dtype* mutable_cpu_diff();
Dtype* mutable_gpu_diff();
void Update();
void FromProto(const BlobProto& proto);
void ToProto(BlobProto* proto, bool write_diff = false) const;
virtual const Dtype* cpu_data() const;
virtual void set_cpu_data(Dtype* data);
virtual void set_gpu_data(Dtype* data);
virtual const Dtype* gpu_data() const;
virtual const Dtype* cpu_diff() const;
virtual const Dtype* gpu_diff() const;
virtual Dtype* mutable_cpu_data();
virtual Dtype* mutable_gpu_data();
virtual Dtype* mutable_cpu_diff();
virtual Dtype* mutable_gpu_diff();
virtual void Update();
virtual void FromProto(const BlobProto& proto);
virtual void ToProto(BlobProto* proto, bool write_diff = false) const;

/// @brief Compute the sum of absolute values (L1 norm) of the data.
Dtype asum_data() const;
9 changes: 9 additions & 0 deletions include/caffe/common.hpp
@@ -4,6 +4,7 @@
#include <boost/shared_ptr.hpp>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <algorithm>

#include <cmath>
#include <fstream> // NOLINT(readability/streams)
@@ -127,6 +128,12 @@
}
#ifndef CPU_ONLY
inline static cublasHandle_t cublas_handle() { return Get().cublas_handle_; }
inline static cusparseHandle_t cusparse_handle() {
return Get().cusparse_handle_;
}
inline static cusparseMatDescr_t cusparse_mat_descr() {
return Get().cusparse_mat_descr_;
}
inline static curandGenerator_t curand_generator() {
return Get().curand_generator_;
}
@@ -155,6 +162,8 @@
protected:
#ifndef CPU_ONLY
cublasHandle_t cublas_handle_;
cusparseHandle_t cusparse_handle_;
cusparseMatDescr_t cusparse_mat_descr_;
curandGenerator_t curand_generator_;
#endif
shared_ptr<RNG> random_generator_;
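Note: the new cusparse_handle() and cusparse_mat_descr() accessors imply that a cuSPARSE handle and matrix descriptor get created alongside the existing cuBLAS/cuRAND handles in the Caffe singleton. A minimal sketch of that setup, assuming the standard cuSPARSE initialization calls rather than the PR's exact code:

// Sketch only (assumption): cuSPARSE state created next to the cuBLAS handle.
#include <cusparse.h>
#include <glog/logging.h>

void InitCusparse(cusparseHandle_t* handle, cusparseMatDescr_t* descr) {
  // One library handle shared by all cuSPARSE calls in the process.
  CHECK_EQ(cusparseCreate(handle), CUSPARSE_STATUS_SUCCESS);
  // A general, zero-indexed matrix descriptor matching SparseBlob's CSR data.
  CHECK_EQ(cusparseCreateMatDescr(descr), CUSPARSE_STATUS_SUCCESS);
  cusparseSetMatType(*descr, CUSPARSE_MATRIX_TYPE_GENERAL);
  cusparseSetMatIndexBase(*descr, CUSPARSE_INDEX_BASE_ZERO);
}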
31 changes: 31 additions & 0 deletions include/caffe/common_layers.hpp
@@ -12,6 +12,7 @@
#include "caffe/loss_layers.hpp"
#include "caffe/neuron_layers.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/sparse_blob.hpp"

namespace caffe {

@@ -487,6 +488,36 @@ class SliceLayer : public Layer<Dtype> {
vector<int> slice_point_;
};

/**
* @brief Also known as a "fully-connected" layer, computes an inner product
* with a set of learned weights, and (optionally) adds biases.
* This layer also supports sparse data (SparseBlob) as input; a minimal
* sketch of the computation appears after this file's diff.
*
* TODO(dox): thorough documentation for Forward, Backward, and proto params.
*/
template<typename Dtype>
class SparseInnerProductLayer : public InnerProductLayer<Dtype> {
public:
explicit SparseInnerProductLayer(const LayerParameter& param)
: InnerProductLayer<Dtype>(param) {}

virtual inline LayerParameter_LayerType type() const {
return LayerParameter_LayerType_SPARSE_INNER_PRODUCT;
}

protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom);
};

} // namespace caffe

#endif // CAFFE_COMMON_LAYERS_HPP_
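For readers unfamiliar with the layer, here is a minimal CPU sketch of the CSR-times-dense product that SparseInnerProductLayer's forward pass needs to compute. The function name, argument layout, and row-major weight convention are illustrative assumptions, not the PR's implementation:

// Illustrative only: top = x_sparse * W^T for one CSR-encoded bottom blob
// (M examples, K input features) and a dense weight matrix W of shape N x K.
#include <vector>

void SparseInnerProductForward(int M, int N, int K,
                               const std::vector<float>& values,   // nzz CSR values
                               const std::vector<int>& indices,    // nzz column indices
                               const std::vector<int>& ptr,        // M + 1 row offsets
                               const std::vector<float>& weight,   // N * K, row-major
                               std::vector<float>* top) {          // M * N output
  top->assign(static_cast<size_t>(M) * N, 0.0f);
  for (int m = 0; m < M; ++m) {
    for (int j = ptr[m]; j < ptr[m + 1]; ++j) {    // non-zeros of row m
      const int k = indices[j];                    // input feature index
      const float v = values[j];
      for (int n = 0; n < N; ++n) {
        (*top)[m * N + n] += v * weight[n * K + k];  // accumulate output row m
      }
    }
  }
}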
50 changes: 50 additions & 0 deletions include/caffe/data_layers.hpp
@@ -7,6 +7,7 @@

#include "boost/scoped_ptr.hpp"
#include "hdf5.h"
#include "leveldb/db.h"

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
@@ -16,6 +17,7 @@
#include "caffe/internal_thread.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/sparse_blob.hpp"

namespace caffe {

@@ -104,6 +106,54 @@ class DataLayer : public BasePrefetchingDataLayer<Dtype> {
Dataset<string, Datum>::const_iterator iter_;
};


template <typename Dtype>
class DataLayerSparseInput : public Layer<Dtype>, public InternalThread {
public:
explicit DataLayerSparseInput(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual ~DataLayerSparseInput();

virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

// Data layers have no bottoms, so reshaping is trivial.
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {}
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {}
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {}

virtual inline LayerParameter_LayerType type() const {
return LayerParameter_LayerType_DATA_SPARSE_INPUT;
}
virtual inline int ExactNumBottomBlobs() const { return 0; }
virtual inline int MinTopBlobs() const { return 1; }
virtual inline int MaxTopBlobs() const { return 2; }

protected:
virtual void CreatePrefetchThread();
virtual void JoinPrefetchThread();
virtual void InternalThreadEntry();

Caffe::Phase phase_;
bool output_labels_;

int datum_size_;
shared_ptr<SparseBlob<Dtype> > prefetch_data_;
shared_ptr<SparseBlob<Dtype> > prefetch_data_copy_;
shared_ptr<Blob<Dtype> > prefetch_label_;
shared_ptr<Blob<Dtype> > prefetch_label_copy_;
shared_ptr<Dataset<string, SparseDatum> > dataset_;
Dataset<string, SparseDatum>::const_iterator iter_;
};

/**
* @brief Provides data to the Net generated by a Filler.
*
6 changes: 5 additions & 1 deletion include/caffe/dataset.hpp
@@ -60,6 +60,9 @@ struct DefaultCoder<Message> {
template <>
struct DefaultCoder<caffe::Datum> : public DefaultCoder<Message> { };

template <>
struct DefaultCoder<caffe::SparseDatum> : public DefaultCoder<Message> { };

template <>
struct DefaultCoder<string> {
static bool serialize(string obj, string* serialized) {
@@ -236,6 +239,7 @@ class Dataset {
#define INSTANTIATE_DATASET(type) \
template class type<string, string>; \
template class type<string, vector<char> >; \
template class type<string, caffe::Datum>;
template class type<string, caffe::Datum>; \
template class type<string, caffe::SparseDatum>;

#endif // CAFFE_DATASET_H_
3 changes: 3 additions & 0 deletions include/caffe/layer.hpp
@@ -470,6 +470,9 @@ void Layer<Dtype>::ToProto(LayerParameter* param, bool write_diff) {
}
}

template <typename Dtype>
Blob<Dtype>* GetTopBlob(const shared_ptr<LayerParameter>& param, int pos);

} // namespace caffe

#endif // CAFFE_LAYER_H_
117 changes: 117 additions & 0 deletions include/caffe/sparse_blob.hpp
@@ -0,0 +1,117 @@
#ifndef CAFFE_SPARSE_BLOB_HPP_
#define CAFFE_SPARSE_BLOB_HPP_

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/syncedmem.hpp"

namespace caffe {

template<typename Dtype>
class SparseBlob : public Blob<Dtype> {
public:
SparseBlob()
: Blob<Dtype>(),
indices_(),
ptr_(),
nzz_(0) {
}

explicit SparseBlob(const int num, const int channels, const int nzz);

virtual void Reshape(const int num, const int channels, const int height,
const int width);

void Reshape(const int num, const int channels, const int nzz);

virtual void ReshapeLike(const Blob<Dtype>& other);

virtual inline int height() const {
return 1;
}
virtual inline int width() const {
return 1;
}
inline int nzz() const {
return nzz_;
}

virtual inline int offset(const int n, const int c = 0, const int h = 0,
const int w = 0) const {
LOG(FATAL)<< "Offset not supported in sparse blob.";
return 0;
}

virtual inline Dtype data_at(const int n, const int c, const int h,
const int w) const {
LOG(FATAL) << "data_at not implemented yet.";
return (Dtype)0;
}

virtual inline Dtype diff_at(const int n, const int c, const int h,
const int w) const {
LOG(FATAL) << "Diff data is not supported in sparse blob.";
return (Dtype)0;
}

inline const shared_ptr<SyncedMemory>& indices() const {
CHECK(indices_);
return indices_;
}

inline const shared_ptr<SyncedMemory>& ptr() const {
CHECK(ptr_);
return ptr_;
}

const int* cpu_indices() const;
const int* cpu_ptr() const;

const int* gpu_indices() const;
const int* gpu_ptr() const;

int* mutable_cpu_indices();
int* mutable_cpu_ptr();

int* mutable_gpu_indices();
int* mutable_gpu_ptr();

virtual void set_cpu_data(Dtype* data);
virtual void set_gpu_data(Dtype* data);

// num and channels are assumed to stay the same, but
// nzz might change, which is why it is an argument.
// The actual size of data and indices may exceed nzz
// to allow for easy slicing.
// If total_size is -1, it is assumed to be equal to nzz.
void set_cpu_data(Dtype* data, int* indices, int* ptr, int nzz,
int total_size=-1);
void set_gpu_data(Dtype* data, int* indices, int* ptr, int nzz,
int total_size=-1);

virtual const Dtype* cpu_diff() const;
virtual const Dtype* gpu_diff() const;
virtual Dtype* mutable_cpu_diff();
virtual Dtype* mutable_gpu_diff();

virtual void ShareData(const Blob<Dtype>& other);
virtual void ShareDiff(const Blob<Dtype>& other);
virtual void CopyFrom(const Blob<Dtype>& source, bool copy_diff = false,
bool reshape = false);

virtual void Update();
virtual void FromProto(const BlobProto& proto);
virtual void ToProto(BlobProto* proto, bool write_diff = false) const;

protected:
shared_ptr<SyncedMemory> indices_;
shared_ptr<SyncedMemory> ptr_;
int nzz_;

DISABLE_COPY_AND_ASSIGN(SparseBlob);
}; // class SparseBlob

} // namespace caffe

#endif // CAFFE_SPARSE_BLOB_HPP_
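Reading note: indices_, ptr_, and nzz_ describe a CSR matrix with num() rows and channels() columns. Under that assumption, and assuming the three-argument Reshape and set_cpu_data behave as their declarations suggest, a small hypothetical use would look like:

// Hypothetical example (not from the PR): a 2 x 4 CSR matrix handed to a SparseBlob.
//   row 0: [0, 5, 0, 8]    row 1: [3, 0, 0, 0]
#include "caffe/sparse_blob.hpp"

void FillExample(caffe::SparseBlob<float>* blob) {
  static float values[]  = {5.0f, 8.0f, 3.0f};  // non-zero values, row by row
  static int   indices[] = {1, 3, 0};           // column index of each value
  static int   ptr[]     = {0, 2, 3};           // row r occupies [ptr[r], ptr[r+1])
  blob->Reshape(2, 4, 3);                       // num = 2, channels = 4, nzz = 3
  blob->set_cpu_data(values, indices, ptr, 3);  // total_size defaults to nzz
}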