diff --git a/Makefile b/Makefile index 1e70c2d5c6e..61fa8fce7d1 100644 --- a/Makefile +++ b/Makefile @@ -165,7 +165,7 @@ INCLUDE_DIRS += $(BUILD_INCLUDE_DIR) ./src ./include ifneq ($(CPU_ONLY), 1) INCLUDE_DIRS += $(CUDA_INCLUDE_DIR) LIBRARY_DIRS += $(CUDA_LIB_DIR) - LIBRARIES := cudart cublas curand + LIBRARIES := cudart cublas cusparse curand endif LIBRARIES += glog gflags protobuf leveldb snappy \ lmdb boost_system hdf5_hl hdf5 m \ diff --git a/include/caffe/blob.hpp b/include/caffe/blob.hpp index ef10aea53f0..f1955870af6 100644 --- a/include/caffe/blob.hpp +++ b/include/caffe/blob.hpp @@ -37,7 +37,7 @@ class Blob { * an error; either Net::Forward or Net::Reshape need to be called to * propagate the new input shape to higher layers. */ - void Reshape(const int num, const int channels, const int height, + virtual void Reshape(const int num, const int channels, const int height, const int width); void ReshapeLike(const Blob& other); inline int num() const { return num_; } @@ -69,38 +69,39 @@ class Blob { void CopyFrom(const Blob& source, bool copy_diff = false, bool reshape = false); - inline Dtype data_at(const int n, const int c, const int h, + virtual inline Dtype data_at(const int n, const int c, const int h, const int w) const { return *(cpu_data() + offset(n, c, h, w)); } - inline Dtype diff_at(const int n, const int c, const int h, + virtual inline Dtype diff_at(const int n, const int c, const int h, const int w) const { return *(cpu_diff() + offset(n, c, h, w)); } - inline const shared_ptr& data() const { + virtual inline const shared_ptr& data() const { CHECK(data_); return data_; } - inline const shared_ptr& diff() const { + virtual inline const shared_ptr& diff() const { CHECK(diff_); return diff_; } - const Dtype* cpu_data() const; - void set_cpu_data(Dtype* data); - const Dtype* gpu_data() const; - const Dtype* cpu_diff() const; - const Dtype* gpu_diff() const; - Dtype* mutable_cpu_data(); - Dtype* mutable_gpu_data(); - Dtype* mutable_cpu_diff(); - Dtype* mutable_gpu_diff(); - void Update(); - void FromProto(const BlobProto& proto); - void ToProto(BlobProto* proto, bool write_diff = false) const; + virtual const Dtype* cpu_data() const; + virtual void set_cpu_data(Dtype* data); + virtual void set_gpu_data(Dtype* data); + virtual const Dtype* gpu_data() const; + virtual const Dtype* cpu_diff() const; + virtual const Dtype* gpu_diff() const; + virtual Dtype* mutable_cpu_data(); + virtual Dtype* mutable_gpu_data(); + virtual Dtype* mutable_cpu_diff(); + virtual Dtype* mutable_gpu_diff(); + virtual void Update(); + virtual void FromProto(const BlobProto& proto); + virtual void ToProto(BlobProto* proto, bool write_diff = false) const; /// @brief Compute the sum of absolute values (L1 norm) of the data. 
Dtype asum_data() const; diff --git a/include/caffe/common.hpp b/include/caffe/common.hpp index 81b2e9ae101..e7b7a3e5a19 100644 --- a/include/caffe/common.hpp +++ b/include/caffe/common.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include // NOLINT(readability/streams) @@ -127,6 +128,12 @@ class Caffe { } #ifndef CPU_ONLY inline static cublasHandle_t cublas_handle() { return Get().cublas_handle_; } + inline static cusparseHandle_t cusparse_handle() { + return Get().cusparse_handle_; + } + inline static cusparseMatDescr_t cusparse_mat_descr() { + return Get().cusparse_mat_descr_; + } inline static curandGenerator_t curand_generator() { return Get().curand_generator_; } @@ -155,6 +162,8 @@ class Caffe { protected: #ifndef CPU_ONLY cublasHandle_t cublas_handle_; + cusparseHandle_t cusparse_handle_; + cusparseMatDescr_t cusparse_mat_descr_; curandGenerator_t curand_generator_; #endif shared_ptr random_generator_; diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp index 9718b825b14..e020bb3c4f5 100644 --- a/include/caffe/common_layers.hpp +++ b/include/caffe/common_layers.hpp @@ -12,6 +12,7 @@ #include "caffe/loss_layers.hpp" #include "caffe/neuron_layers.hpp" #include "caffe/proto/caffe.pb.h" +#include "caffe/sparse_blob.hpp" namespace caffe { @@ -487,6 +488,36 @@ class SliceLayer : public Layer { vector slice_point_; }; +/** + * @brief Also known as a "fully-connected" layer, computes an inner product + * with a set of learned weights, and (optionally) adds biases. + * This layer also supports sparse data (SparseBlob) as input. + * + * TODO(dox): thorough documentation for Forward, Backward, and proto params. + */ +template +class SparseInnerProductLayer : public InnerProductLayer { + public: + explicit SparseInnerProductLayer(const LayerParameter& param) + : InnerProductLayer(param) {} + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_SPARSE_INNER_PRODUCT; + } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom); +}; + } // namespace caffe #endif // CAFFE_COMMON_LAYERS_HPP_ diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp index 34b9b30aa3e..1eab0e19848 100644 --- a/include/caffe/data_layers.hpp +++ b/include/caffe/data_layers.hpp @@ -7,6 +7,7 @@ #include "boost/scoped_ptr.hpp" #include "hdf5.h" +#include "leveldb/db.h" #include "caffe/blob.hpp" #include "caffe/common.hpp" @@ -16,6 +17,7 @@ #include "caffe/internal_thread.hpp" #include "caffe/layer.hpp" #include "caffe/proto/caffe.pb.h" +#include "caffe/sparse_blob.hpp" namespace caffe { @@ -104,6 +106,58 @@ class DataLayer : public BasePrefetchingDataLayer { Dataset::const_iterator iter_; }; +template +void* DataLayerSparseInputPrefetch(void* layer_pointer); + +template +class DataLayerSparseInput : public Layer { + // The function used to perform prefetching. 
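+ // Prefetching runs on a dedicated pthread: it fills prefetch_data_ and
+ // prefetch_label_, which Forward then swaps with prefetch_data_copy_ and
+ // prefetch_label_copy_, so the net consumes one buffer while the next
+ // batch is read from LevelDB.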
+ friend void* DataLayerSparseInputPrefetch(void* layer_pointer); + + virtual void Reshape(const vector*>& bottom, + const vector*>& top) {} + + public: + explicit DataLayerSparseInput(const LayerParameter& param) + : Layer(param) { + } + virtual ~DataLayerSparseInput(); + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + return; + } + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + return; + } + + virtual void CreatePrefetchThread(); + virtual void JoinPrefetchThread(); + + shared_ptr db_; + shared_ptr iter_; + int datum_size_; + + pthread_t thread_; + shared_ptr > prefetch_data_; + shared_ptr > prefetch_data_copy_; + shared_ptr > prefetch_label_; + shared_ptr > prefetch_label_copy_; + + bool output_labels_; + Caffe::Phase phase_; +}; + /** * @brief Provides data to the Net generated by a Filler. * diff --git a/include/caffe/layer.hpp b/include/caffe/layer.hpp index 8a8330bca57..dc094bc6d39 100644 --- a/include/caffe/layer.hpp +++ b/include/caffe/layer.hpp @@ -470,6 +470,9 @@ void Layer::ToProto(LayerParameter* param, bool write_diff) { } } +template +Blob* GetTopBlob(const shared_ptr& param, int pos); + } // namespace caffe #endif // CAFFE_LAYER_H_ diff --git a/include/caffe/sparse_blob.hpp b/include/caffe/sparse_blob.hpp new file mode 100644 index 00000000000..4711e5caf25 --- /dev/null +++ b/include/caffe/sparse_blob.hpp @@ -0,0 +1,117 @@ +#ifndef CAFFE_SPARSE_BLOB_HPP_ +#define CAFFE_SPARSE_BLOB_HPP_ + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/syncedmem.hpp" + +namespace caffe { + +template +class SparseBlob : public Blob { + public: + SparseBlob() + : Blob(), + indices_(), + ptr_(), + nzz_(0) { + } + + explicit SparseBlob(const int num, const int channels, const int nzz); + + virtual void Reshape(const int num, const int channels, const int height, + const int width); + + void Reshape(const int num, const int channels, const int nzz); + + virtual void ReshapeLike(const Blob& other); + + virtual inline int height() const { + return 1; + } + virtual inline int width() const { + return 1; + } + inline int nzz() const { + return nzz_; + } + + virtual inline int offset(const int n, const int c = 0, const int h = 0, + const int w = 0) const { + LOG(FATAL)<< "Offset not supported in sparse blob."; + return 0; + } + + virtual inline Dtype data_at(const int n, const int c, const int h, + const int w) const { + LOG(FATAL) << "data_at not implemented yet."; + return (Dtype)0; + } + + virtual inline Dtype diff_at(const int n, const int c, const int h, + const int w) const { + LOG(FATAL) << "Diff data is not supported in sparse blob."; + return (Dtype)0; + } + + inline const shared_ptr& indices() const { + CHECK(indices_); + return indices_; + } + + inline const shared_ptr& ptr() const { + CHECK(ptr_); + return ptr_; + } + + const int* cpu_indices() const; + const int* cpu_ptr() const; + + const int* gpu_indices() const; + const int* gpu_ptr() const; + + int* mutable_cpu_indices(); + int* mutable_cpu_ptr(); + + int* mutable_gpu_indices(); + int* mutable_gpu_ptr(); + + virtual void set_cpu_data(Dtype* data); + virtual void set_gpu_data(Dtype* data); + + // the num and channels are 
assumed to stay the same, but + // nzz might change, which is why it is an argument; + // also, the actual size of data and indices might exceed nzz + // to allow for easy slicing. + // If total_size is -1, it is assumed to be equal to nzz + void set_cpu_data(Dtype* data, int* indices, int* ptr, int nzz, + int total_size=-1); + void set_gpu_data(Dtype* data, int* indices, int* ptr, int nzz, + int total_size=-1); + + virtual const Dtype* cpu_diff() const; + virtual const Dtype* gpu_diff() const; + virtual Dtype* mutable_cpu_diff(); + virtual Dtype* mutable_gpu_diff(); + + virtual void ShareData(const Blob& other); + virtual void ShareDiff(const Blob& other); + virtual void CopyFrom(const Blob& source, bool copy_diff = false, + bool reshape = false); + + virtual void Update(); + virtual void FromProto(const BlobProto& proto); + virtual void ToProto(BlobProto* proto, bool write_diff = false) const; + + protected: + shared_ptr indices_; + shared_ptr ptr_; + int nzz_; + + DISABLE_COPY_AND_ASSIGN(SparseBlob); +}; // class SparseBlob + +} // namespace caffe + +#endif // CAFFE_SPARSE_BLOB_HPP_ diff --git a/include/caffe/syncedmem.hpp b/include/caffe/syncedmem.hpp index db8d0e80e12..44b2f9bf8a0 100644 --- a/include/caffe/syncedmem.hpp +++ b/include/caffe/syncedmem.hpp @@ -41,13 +41,16 @@ class SyncedMemory { public: SyncedMemory() : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED), - own_cpu_data_(false) {} + own_cpu_data_(false), own_gpu_data_(false) {} explicit SyncedMemory(size_t size) : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED), - own_cpu_data_(false) {} + own_cpu_data_(false), own_gpu_data_(false) {} ~SyncedMemory(); const void* cpu_data(); - void set_cpu_data(void* data); + + // if size is -1 the size is not changed + void set_cpu_data(void* data, int size=-1); + void set_gpu_data(void* data, int size=-1); const void* gpu_data(); void* mutable_cpu_data(); void* mutable_gpu_data(); @@ -58,11 +61,13 @@ class SyncedMemory { private: void to_cpu(); void to_gpu(); + void clear_data(); void* cpu_ptr_; void* gpu_ptr_; size_t size_; SyncedHead head_; bool own_cpu_data_; + bool own_gpu_data_; DISABLE_COPY_AND_ASSIGN(SyncedMemory); }; // class SyncedMemory diff --git a/include/caffe/test/test_gradient_check_util.hpp b/include/caffe/test/test_gradient_check_util.hpp index 22937711b58..68162f69d75 100644 --- a/include/caffe/test/test_gradient_check_util.hpp +++ b/include/caffe/test/test_gradient_check_util.hpp @@ -82,15 +82,17 @@ void GradientChecker::CheckGradientSingle(Layer* layer, } // First, figure out what blobs we need to check against. 
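// With this change, check_bottom == -1 checks every bottom blob, a value
// >= 0 checks only that bottom, and anything < -1 skips the bottom checks
// entirely (useful for bottoms, such as sparse inputs, with no gradient).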
vector*> blobs_to_check; - vector propagate_down(bottom.size(), check_bottom < 0); + // if check_bottom is < -1 no bottom layer will be checked + vector propagate_down(bottom.size(), check_bottom == -1); for (int i = 0; i < layer->blobs().size(); ++i) { blobs_to_check.push_back(layer->blobs()[i].get()); } - if (check_bottom < 0) { + // if check_bottom is < -1 no bottom layer will be checked + if (check_bottom == -1) { for (int i = 0; i < bottom.size(); ++i) { blobs_to_check.push_back(bottom[i]); } - } else { + } else if (check_bottom >= 0) { CHECK_LT(check_bottom, bottom.size()); blobs_to_check.push_back(bottom[check_bottom]); propagate_down[check_bottom] = true; diff --git a/include/caffe/util/device_alternate.hpp b/include/caffe/util/device_alternate.hpp index 5a45691bb17..4d54db81522 100644 --- a/include/caffe/util/device_alternate.hpp +++ b/include/caffe/util/device_alternate.hpp @@ -35,6 +35,7 @@ void classname::funcname##_##gpu(const vector*>& top, \ #include #include #include +#include #include // cuda driver types #ifdef USE_CUDNN // cuDNN acceleration library. #include "caffe/util/cudnn.hpp" @@ -59,6 +60,13 @@ void classname::funcname##_##gpu(const vector*>& top, \ << caffe::cublasGetErrorString(status); \ } while (0) +#define CUSPARSE_CHECK(condition) \ + do { \ + cusparseStatus_t status = condition; \ + CHECK_EQ(status, CUSPARSE_STATUS_SUCCESS) << " " \ + << caffe::cusparseGetErrorString(status); \ + } while (0) + #define CURAND_CHECK(condition) \ do { \ curandStatus_t status = condition; \ @@ -79,6 +87,7 @@ namespace caffe { // CUDA: library error reporting. const char* cublasGetErrorString(cublasStatus_t error); +const char* cusparseGetErrorString(cusparseStatus_t error); const char* curandGetErrorString(curandStatus_t error); // CUDA: thread number configuration. 
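For illustration, here is a minimal sketch of how the new CUSPARSE_CHECK macro composes with the handle accessor added to common.hpp; it is not part of the patch, BindCusparseToStream is a hypothetical helper, and cusparseSetStream is used only as an example of a status-returning cuSPARSE call:

#include <cusparse_v2.h>

// Sketch: bind subsequent cuSPARSE work to a CUDA stream.
void BindCusparseToStream(cudaStream_t stream) {
  // The macro evaluates the call once; on anything other than
  // CUSPARSE_STATUS_SUCCESS it logs caffe::cusparseGetErrorString(status)
  // and aborts, mirroring CUBLAS_CHECK and CURAND_CHECK.
  CUSPARSE_CHECK(cusparseSetStream(caffe::Caffe::cusparse_handle(), stream));
}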
diff --git a/include/caffe/util/math_functions.hpp b/include/caffe/util/math_functions.hpp index 12823f8da0c..7f8859e2cce 100644 --- a/include/caffe/util/math_functions.hpp +++ b/include/caffe/util/math_functions.hpp @@ -20,6 +20,14 @@ void caffe_cpu_gemm(const CBLAS_TRANSPOSE TransA, const Dtype alpha, const Dtype* A, const Dtype* B, const Dtype beta, Dtype* C); +template +void caffe_cpu_csr_gemm(const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const Dtype alpha, const int nzz, + const Dtype* A, const int* indices, const int* ptr, + const Dtype* B, const Dtype beta, Dtype* C, + const CBLAS_ORDER orderC); + template void caffe_cpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N, const Dtype alpha, const Dtype* A, const Dtype* x, const Dtype beta, @@ -27,7 +35,7 @@ void caffe_cpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N, template void caffe_axpy(const int N, const Dtype alpha, const Dtype* X, - Dtype* Y); + Dtype* Y, const int ldx = 1, const int ldy = 1); template void caffe_cpu_axpby(const int N, const Dtype alpha, const Dtype* X, @@ -153,6 +161,13 @@ void caffe_gpu_gemm(const CBLAS_TRANSPOSE TransA, const Dtype alpha, const Dtype* A, const Dtype* B, const Dtype beta, Dtype* C); +template +void caffe_gpu_csr_gemm(const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const Dtype alpha, int nzz, const Dtype* A, + const int* indices, const int* ptr, const Dtype* B, + const Dtype beta, Dtype* C, const CBLAS_ORDER orderC); + template void caffe_gpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N, const Dtype alpha, const Dtype* A, const Dtype* x, const Dtype beta, diff --git a/scripts/travis/travis_install.sh b/scripts/travis/travis_install.sh index e17f253ecdc..7f9033f8e86 100755 --- a/scripts/travis/travis_install.sh +++ b/scripts/travis/travis_install.sh @@ -39,7 +39,7 @@ if $WITH_CUDA; then apt-get -y update # Install the minimal CUDA subpackages required to test Caffe build. # For a full CUDA installation, add 'cuda' to the list of packages. - apt-get -y install cuda-core-6-5 cuda-cublas-6-5 cuda-cublas-dev-6-5 cuda-cudart-6-5 cuda-cudart-dev-6-5 cuda-curand-6-5 cuda-curand-dev-6-5 + apt-get -y install cuda-core-6-5 cuda-cublas-6-5 cuda-cublas-dev-6-5 cuda-cusparse-6-5 cuda-cusparse-dev-6-5 cuda-cudart-6-5 cuda-cudart-dev-6-5 cuda-curand-6-5 cuda-curand-dev-6-5 # Create CUDA symlink at /usr/local/cuda # (This would normally be created by the CUDA installer, but we create it # manually since we did a partial installation.) 
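For reference, the caffe_cpu_csr_gemm declaration above computes C = alpha * op(A) * op(B) + beta * C, where A is an M x K matrix in CSR form (nzz non-zero values, their column indices, and a row-pointer array of M + 1 entries). Below is a naive sketch of the simplest case (TransA = TransB = CblasNoTrans, orderC = CblasRowMajor) for exposition only; it is not the implementation in this patch, and K/nzz are dropped from the hypothetical helper's signature since ptr and indices imply them:

template <typename Dtype>
void csr_gemm_nn_rowmajor(const int M, const int N, const Dtype alpha,
                          const Dtype* A, const int* indices, const int* ptr,
                          const Dtype* B, const Dtype beta, Dtype* C) {
  // C = beta * C, then accumulate alpha * A * B one sparse row at a time.
  for (int i = 0; i < M * N; ++i) C[i] *= beta;
  for (int row = 0; row < M; ++row) {
    for (int p = ptr[row]; p < ptr[row + 1]; ++p) {  // non-zeros of this row
      const Dtype v = alpha * A[p];
      const int k = indices[p];  // column in A == row in B
      for (int j = 0; j < N; ++j) {
        C[row * N + j] += v * B[k * N + j];
      }
    }
  }
}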
diff --git a/src/caffe/CMakeLists.txt b/src/caffe/CMakeLists.txt index dda072688f8..eaf4a079b94 100644 --- a/src/caffe/CMakeLists.txt +++ b/src/caffe/CMakeLists.txt @@ -26,6 +26,7 @@ if(NOT CPU_ONLY) add_dependencies(caffe_cu proto) target_link_libraries(caffe caffe_cu ${CUDA_CUBLAS_LIBRARIES} + ${CUDA_cusparse_LIBRARY} ${CUDA_curand_LIBRARY} ) endif() diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp index cfffc379eb1..df5363668da 100644 --- a/src/caffe/blob.cpp +++ b/src/caffe/blob.cpp @@ -49,6 +49,12 @@ void Blob::set_cpu_data(Dtype* data) { data_->set_cpu_data(data); } +template +void Blob::set_gpu_data(Dtype* data) { + CHECK(data); + data_->set_gpu_data(data); +} + template const Dtype* Blob::gpu_data() const { CHECK(data_); diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp index 834d5694aad..b532cc19b9d 100644 --- a/src/caffe/common.cpp +++ b/src/caffe/common.cpp @@ -85,13 +85,27 @@ void* Caffe::RNG::generator() { #else // Normal GPU + CPU Caffe. Caffe::Caffe() - : cublas_handle_(NULL), curand_generator_(NULL), random_generator_(), - mode_(Caffe::CPU), phase_(Caffe::TRAIN) { + : cublas_handle_(NULL), + cusparse_handle_(NULL), + curand_generator_(NULL), + random_generator_(), + mode_(Caffe::CPU), + phase_(Caffe::TRAIN) { // Try to create a cublas handler, and report an error if failed (but we will // keep the program running as one might just want to run CPU code). if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) { LOG(ERROR) << "Cannot create Cublas handle. Cublas won't be available."; } + if (cusparseCreate(&cusparse_handle_) != CUSPARSE_STATUS_SUCCESS) { + LOG(ERROR) << "Cannot create Cusparse handle. Cusparse won't be available."; + } + if (cusparseCreateMatDescr(&cusparse_mat_descr_) != CUSPARSE_STATUS_SUCCESS) { + LOG(ERROR) << "Cannot create Cusparse mat description. " + "Cusparse won't be available."; + } else { + cusparseSetMatType(cusparse_mat_descr_, CUSPARSE_MATRIX_TYPE_GENERAL); + cusparseSetMatIndexBase(cusparse_mat_descr_, CUSPARSE_INDEX_BASE_ZERO); + } // Try to create a curand handler. if (curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT) != CURAND_STATUS_SUCCESS || @@ -103,6 +117,10 @@ Caffe::Caffe() Caffe::~Caffe() { if (cublas_handle_) CUBLAS_CHECK(cublasDestroy(cublas_handle_)); + if (cusparse_handle_) + CUSPARSE_CHECK(cusparseDestroy(cusparse_handle_)); + if (cusparse_mat_descr_) + CUSPARSE_CHECK(cusparseDestroyMatDescr(cusparse_mat_descr_)); if (curand_generator_) { CURAND_CHECK(curandDestroyGenerator(curand_generator_)); } @@ -136,10 +154,18 @@ void Caffe::SetDevice(const int device_id) { // may perform initialization using the GPU. 
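// The cublas/cusparse handles and the curand generator are destroyed and
// recreated below so that they end up bound to the newly selected device.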
CUDA_CHECK(cudaSetDevice(device_id)); if (Get().cublas_handle_) CUBLAS_CHECK(cublasDestroy(Get().cublas_handle_)); + if (Get().cusparse_handle_) + CUSPARSE_CHECK(cusparseDestroy(Get().cusparse_handle_)); + if (Get().cusparse_mat_descr_) + CUSPARSE_CHECK(cusparseDestroyMatDescr(Get().cusparse_mat_descr_)); if (Get().curand_generator_) { CURAND_CHECK(curandDestroyGenerator(Get().curand_generator_)); } CUBLAS_CHECK(cublasCreate(&Get().cublas_handle_)); + CUSPARSE_CHECK(cusparseCreate(&Get().cusparse_handle_)); + CUSPARSE_CHECK(cusparseCreateMatDescr(&Get().cusparse_mat_descr_)); + cusparseSetMatType(Get().cusparse_mat_descr_, CUSPARSE_MATRIX_TYPE_GENERAL); + cusparseSetMatIndexBase(Get().cusparse_mat_descr_, CUSPARSE_INDEX_BASE_ZERO); CURAND_CHECK(curandCreateGenerator(&Get().curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)); CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(Get().curand_generator_, @@ -234,6 +260,32 @@ const char* cublasGetErrorString(cublasStatus_t error) { return "Unknown cublas status"; } +const char* cusparseGetErrorString(cusparseStatus_t error) { + switch (error) { + case CUSPARSE_STATUS_SUCCESS: + return "CUSPARSE_STATUS_SUCCESS"; + case CUSPARSE_STATUS_NOT_INITIALIZED: + return "CUSPARSE_STATUS_NOT_INITIALIZED"; + case CUSPARSE_STATUS_ALLOC_FAILED: + return "CUSPARSE_STATUS_ALLOC_FAILED"; + case CUSPARSE_STATUS_INVALID_VALUE: + return "CUSPARSE_STATUS_INVALID_VALUE"; + case CUSPARSE_STATUS_ARCH_MISMATCH: + return "CUSPARSE_STATUS_ARCH_MISMATCH"; + case CUSPARSE_STATUS_MAPPING_ERROR: + return "CUSPARSE_STATUS_MAPPING_ERROR"; + case CUSPARSE_STATUS_EXECUTION_FAILED: + return "CUSPARSE_STATUS_EXECUTION_FAILED"; + case CUSPARSE_STATUS_INTERNAL_ERROR: + return "CUSPARSE_STATUS_INTERNAL_ERROR"; + case CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED: + return "CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED"; + case CUSPARSE_STATUS_ZERO_PIVOT: + return "CUSPARSE_STATUS_ZERO_PIVOT"; + } + return "Unknown CUSPARSE status"; +} + const char* curandGetErrorString(curandStatus_t error) { switch (error) { case CURAND_STATUS_SUCCESS: diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp index 5a286cd4691..95d8065442f 100644 --- a/src/caffe/layer_factory.cpp +++ b/src/caffe/layer_factory.cpp @@ -1,8 +1,11 @@ #include +#include "caffe/blob.hpp" +#include "caffe/common.hpp" #include "caffe/layer.hpp" #include "caffe/layer_factory.hpp" #include "caffe/proto/caffe.pb.h" +#include "caffe/sparse_blob.hpp" #include "caffe/vision_layers.hpp" namespace caffe { @@ -155,4 +158,26 @@ REGISTER_LAYER_CREATOR(TANH, GetTanHLayer); // Layers that use their constructor as their default creator should be // registered in their corresponding cpp files. Do not registere them here. +template +Blob* GetTopBlob(const shared_ptr& param, int pos) { + const LayerParameter_LayerType& type = param->type(); + switch (type) { + case LayerParameter_LayerType_DATA_SPARSE_INPUT: + if (pos == 0) { + return new SparseBlob(); + } else { + return new Blob(); + } + default: + return new Blob(); + } + // just to suppress old compiler warnings. 
+ return new Blob(); +} + +template Blob* GetTopBlob(const shared_ptr& param, + int pos); +template Blob* GetTopBlob(const shared_ptr& param, + int pos); + } // namespace caffe diff --git a/src/caffe/layers/data_layer_sparse_input.cpp b/src/caffe/layers/data_layer_sparse_input.cpp new file mode 100644 index 00000000000..4f1ef345dcf --- /dev/null +++ b/src/caffe/layers/data_layer_sparse_input.cpp @@ -0,0 +1,219 @@ +#include +#include +#include + +#include +#include + +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/util/rng.hpp" +#include "caffe/vision_layers.hpp" + +using std::string; + +namespace caffe { + +template +void* DataLayerSparseInputPrefetch(void* layer_pointer) { + CHECK(layer_pointer); + DataLayerSparseInput* layer = + static_cast*>(layer_pointer); + CHECK(layer); + vector > datums; + CHECK(layer->prefetch_data_); + + Dtype* top_label = NULL; // suppress warnings about uninitialized variables + if (layer->output_labels_) { + top_label = layer->prefetch_label_->mutable_cpu_data(); + } + + const int batch_size = layer->layer_param_.data_sparse_input_param() + .batch_size(); + const int size = layer->datum_size_; + + for (int item_id = 0; item_id < batch_size; ++item_id) { + CHECK(layer->iter_); + CHECK(layer->iter_->Valid()); + shared_ptr datum(new SparseDatum()); + + datum->ParseFromString(layer->iter_->value().ToString()); + datums.push_back(datum); + if (layer->output_labels_) { + top_label[item_id] = datum->label(); + } + // go to the next iter + layer->iter_->Next(); + if (!layer->iter_->Valid()) { + // We have reached the end. Restart from the first. + layer->iter_->SeekToFirst(); + } + } + int nn = 0; + for (int i = 0; i < batch_size; i++) { + nn += datums[i]->nn(); + } + layer->prefetch_data_->Reshape(batch_size, size, nn); + + Dtype* top_data = layer->prefetch_data_->mutable_cpu_data(); + int* indices = layer->prefetch_data_->mutable_cpu_indices(); + int* ptr = layer->prefetch_data_->mutable_cpu_ptr(); + + ptr[0] = 0; + int pos = 0; + for (int i = 0; i < batch_size; i++) { + shared_ptr d = datums[i]; + for (int k = 0; k < d->nn(); k++) { + top_data[k + pos] = d->data(k); + indices[k + pos] = d->indices(k); + } + pos += d->nn(); + ptr[i + 1] = pos; + } + return static_cast(NULL); +} + +template +DataLayerSparseInput::~DataLayerSparseInput() { + JoinPrefetchThread(); +} + +template +void DataLayerSparseInput::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + CHECK_EQ(bottom.size(), 0)<< "Data Layer takes no input blobs."; + CHECK_GE(top.size(), 1) << "Data Layer takes at least one blob as output."; + CHECK_LE(top.size(), 2) << "Data Layer takes at most two blobs as output."; + if (top.size() == 1) { + output_labels_ = false; + } else { + output_labels_ = true; + } + // Initialize the leveldb + leveldb::DB* db_temp; + leveldb::Options options; + options.create_if_missing = false; + options.max_open_files = 100; + LOG(INFO) << "Opening leveldb " + << this->layer_param_.data_sparse_input_param().source(); + leveldb::Status status = leveldb::DB::Open( + options, this->layer_param_.data_sparse_input_param().source(), &db_temp); + CHECK(status.ok()) << "Failed to open leveldb " + << this->layer_param_.data_sparse_input_param().source() << std::endl + << status.ToString(); + db_.reset(db_temp); + iter_.reset(db_->NewIterator(leveldb::ReadOptions())); + iter_->SeekToFirst(); + // Check if we would need to randomly skip a few data points + if 
(this->layer_param_.data_sparse_input_param().rand_skip()) { + unsigned int skip = caffe_rng_rand() % + this->layer_param_.data_sparse_input_param().rand_skip(); + LOG(INFO) << "Skipping first " << skip << " data points."; + while (skip-- > 0) { + iter_->Next(); + if (!iter_->Valid()) { + iter_->SeekToFirst(); + } + } + } + // Read a data point, and use it to initialize the top blob. + SparseDatum datum; + datum.ParseFromString(iter_->value().ToString()); + + if ( SparseBlob * sparseBlob = + dynamic_cast*>(top[0])) { + sparseBlob -> Reshape( + this->layer_param_.data_sparse_input_param().batch_size(), + datum.size(), 1); + } else { + LOG(FATAL) << "The top blob in the data layer sparse is not sparse\n"; + } + prefetch_data_.reset(new SparseBlob( + this->layer_param_.data_sparse_input_param().batch_size(), + datum.size(), 1)); + prefetch_data_copy_.reset(new SparseBlob( + this->layer_param_.data_sparse_input_param().batch_size(), + datum.size(), 1)); + + LOG(INFO) << "output data size: " << top[0]->num() << "," + << top[0]->channels() << "," << top[0]->height() << "," + << top[0]->width(); + // label + if (output_labels_) { + top[1]->Reshape( + this->layer_param_.data_sparse_input_param().batch_size(), + 1, 1, 1); + prefetch_label_.reset( + new Blob( + this->layer_param_.data_sparse_input_param().batch_size(), + 1, 1, 1)); + prefetch_label_copy_.reset( + new Blob( + this->layer_param_.data_sparse_input_param().batch_size(), + 1, 1, 1)); + } + datum_size_ = datum.size(); + + // Now, start the prefetch thread. Before calling prefetch, we make two + // cpu_data calls so that the prefetch thread does not accidentally make + // simultaneous cudaMalloc calls when the main thread is running. In some + // GPUs this seems to cause failures if we do not so. + prefetch_data_->mutable_cpu_data(); + if (output_labels_) { + prefetch_label_->mutable_cpu_data(); + } + DLOG(INFO) << "Initializing prefetch"; + CreatePrefetchThread(); + DLOG(INFO) << "Prefetch initialized."; +} + +template +void DataLayerSparseInput::CreatePrefetchThread() { + // Create the thread. 
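+ // (pthread_create returns 0 on success, hence the CHECK on its negation;
+ // the thread runs DataLayerSparseInputPrefetch<Dtype>, which fills
+ // prefetch_data_ with one full batch before exiting.)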
+ CHECK(!pthread_create(&thread_, NULL, DataLayerSparseInputPrefetch, + static_cast(this))) << "Pthread execution failed."; +} + +template +void DataLayerSparseInput::JoinPrefetchThread() { + CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed."; +} + +template +void DataLayerSparseInput::Forward_cpu( + const vector*>& bottom, const vector*>& top) { + // First, join the thread + JoinPrefetchThread(); + // we swap the prefetch data + prefetch_data_.swap(prefetch_data_copy_); + prefetch_label_.swap(prefetch_label_copy_); + + // Start a new prefetch thread ahead of any memory transfer + CreatePrefetchThread(); + + if (SparseBlob * sparseBlob = + dynamic_cast*>(top[0])) { + sparseBlob->set_cpu_data( + const_cast(prefetch_data_copy_->cpu_data()), + const_cast(prefetch_data_copy_->cpu_indices()), + const_cast(prefetch_data_copy_->cpu_ptr()), + prefetch_data_copy_->nzz(), prefetch_data_copy_->nzz()); + } else { + LOG(FATAL)<< "The top blob in the data layer sparse is not sparse\n"; + } + if (output_labels_) { + caffe_copy(prefetch_label_copy_->count(), prefetch_label_copy_->cpu_data(), + top[1]->mutable_cpu_data()); + } +} + +#ifdef CPU_ONLY +STUB_GPU_FORWARD(DataLayerSparseInput, Forward); +#endif + +INSTANTIATE_CLASS(DataLayerSparseInput); +REGISTER_LAYER_CLASS(DATA_SPARSE_INPUT, DataLayerSparseInput); + +} // namespace caffe diff --git a/src/caffe/layers/data_layer_sparse_input.cu b/src/caffe/layers/data_layer_sparse_input.cu new file mode 100644 index 00000000000..00a76600b27 --- /dev/null +++ b/src/caffe/layers/data_layer_sparse_input.cu @@ -0,0 +1,44 @@ +#include +#include +#include + +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/util/io.hpp" +#include "caffe/vision_layers.hpp" + +using std::string; + +namespace caffe { + +template +void DataLayerSparseInput::Forward_gpu( + const vector*>& bottom, const vector*>& top) { + // First, join the thread + JoinPrefetchThread(); + prefetch_data_.swap(prefetch_data_copy_); + prefetch_label_.swap(prefetch_label_copy_); + + // Start a new prefetch thread + CreatePrefetchThread(); + + if (SparseBlob * sparseBlob = + dynamic_cast*>(top[0])) { + sparseBlob->set_gpu_data( + const_cast(prefetch_data_copy_->gpu_data()), + const_cast(prefetch_data_copy_->gpu_indices()), + const_cast(prefetch_data_copy_->gpu_ptr()), + prefetch_data_copy_->nzz(), prefetch_data_copy_->nzz()); + } else { + LOG(FATAL)<< "The top blob in the data layer sparse is not sparse\n"; + } + + if (output_labels_) { + caffe_copy(prefetch_label_copy_->count(), prefetch_label_copy_->cpu_data(), + top[1]->mutable_gpu_data()); + } +} +INSTANTIATE_LAYER_GPU_FUNCS(DataLayerSparseInput); +} // namespace caffe diff --git a/src/caffe/layers/sparse_inner_product_layer.cpp b/src/caffe/layers/sparse_inner_product_layer.cpp new file mode 100644 index 00000000000..2ecef008752 --- /dev/null +++ b/src/caffe/layers/sparse_inner_product_layer.cpp @@ -0,0 +1,97 @@ +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { +template +void SparseInnerProductLayer::Forward_cpu( + const vector*>& bottom, + const vector*>& top) { + SparseBlob * bottomSparseBlob = + dynamic_cast*>(bottom[0]); + + if (bottomSparseBlob == 0) { // fall back to dense computation + InnerProductLayer::Forward_cpu(bottom, top); + return; + } + const Dtype* bottom_data = bottomSparseBlob->cpu_data(); + const int* bottom_indices = 
bottomSparseBlob->cpu_indices(); + const int* bottom_ptr = bottomSparseBlob->cpu_ptr(); + const int nzz = bottomSparseBlob->nzz(); + + Dtype* top_data = top[0]->mutable_cpu_data(); + const Dtype* weight = this->blobs_[0]->cpu_data(); + + caffe_cpu_csr_gemm(CblasNoTrans, CblasTrans, this->M_, + this->N_, + this->K_, (Dtype) 1., nzz, bottom_data, + bottom_indices, bottom_ptr, weight, + (Dtype) 0., + top_data, CblasRowMajor); + + if (this->bias_term_) { + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, this->M_, this->N_, 1, + (Dtype) 1., this->bias_multiplier_.cpu_data(), + this->blobs_[1]->cpu_data(), (Dtype) 1., top_data); + } +} + +template +void SparseInnerProductLayer::Backward_cpu( + const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + + SparseBlob * bottomSparseBlob = + dynamic_cast*>(bottom[0]); + // fall back to dense computation + if (bottomSparseBlob == 0) { + InnerProductLayer::Backward_cpu(top, propagate_down, bottom); + return; + } + if (this->param_propagate_down_[0]) { + // Gradient with respect to weight + const Dtype* top_diff = top[0]->cpu_diff(); + const Dtype* bottom_data = bottomSparseBlob->cpu_data(); + const int* bottom_indices = bottomSparseBlob->cpu_indices(); + const int* bottom_ptr = bottomSparseBlob->cpu_ptr(); + const int nzz = bottomSparseBlob->nzz(); + caffe_cpu_csr_gemm(CblasTrans, CblasNoTrans, this->K_, + this->N_, + this->M_, (Dtype) 1., nzz, bottom_data, + bottom_indices, bottom_ptr, top_diff, + (Dtype) 0., + this->blobs_[0]->mutable_cpu_diff(), + CblasColMajor); + } + + if (this->bias_term_ && this->param_propagate_down_[1]) { + // Gradient with respect to bias + const Dtype* top_diff = top[0]->cpu_diff(); + caffe_cpu_gemv(CblasTrans, this->M_, this->N_, (Dtype) 1., + top_diff, + this->bias_multiplier_.cpu_data(), + (Dtype) 0., + this->blobs_[1]->mutable_cpu_diff()); + } + if (propagate_down[0]) { + // there is a bug in the code because this is called no matter what! 
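+ // (Gradients with respect to a CSR-encoded bottom are not implemented in
+ // this patch, yet Net::Init can still set propagate_down[0], so this
+ // fails hard instead of silently skipping the bottom gradient.)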
+ LOG(FATAL) << "propagate down not supported for sparse inner product"; + } +} + +#ifdef CPU_ONLY +STUB_GPU(SparseInnerProductLayer); +#endif + +INSTANTIATE_CLASS(SparseInnerProductLayer); +REGISTER_LAYER_CLASS(SPARSE_INNER_PRODUCT, SparseInnerProductLayer); + +} // namespace caffe + + diff --git a/src/caffe/layers/sparse_inner_product_layer.cu b/src/caffe/layers/sparse_inner_product_layer.cu new file mode 100644 index 00000000000..3b0d4b968bb --- /dev/null +++ b/src/caffe/layers/sparse_inner_product_layer.cu @@ -0,0 +1,85 @@ +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void SparseInnerProductLayer::Forward_gpu( + const vector*>& bottom, + const vector*>& top) { + SparseBlob * bottomSparseBlob = + dynamic_cast*>(bottom[0]); + // fall back to dense computation + if (bottomSparseBlob == 0) { // fall back to dense computation + InnerProductLayer::Forward_gpu(bottom, top); + return; + } + const Dtype* bottom_data = bottomSparseBlob->gpu_data(); + const int* bottom_indices = bottomSparseBlob->gpu_indices(); + const int* bottom_ptr = bottomSparseBlob->gpu_ptr(); + const int nzz = bottomSparseBlob->nzz(); + Dtype* top_data = top[0]->mutable_gpu_data(); + const Dtype* weight = this->blobs_[0]->gpu_data(); + caffe_gpu_csr_gemm(CblasNoTrans, CblasTrans, this->M_, this->N_, + this->K_, (Dtype) 1., nzz, bottom_data, + bottom_indices, bottom_ptr, weight, (Dtype) 0., + top_data, CblasRowMajor); + + if (this->bias_term_) { + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, this->M_, this->N_, 1, + (Dtype) 1., this->bias_multiplier_.gpu_data(), + this->blobs_[1]->gpu_data(), (Dtype) 1., top_data); + } +} + +template +void SparseInnerProductLayer::Backward_gpu( + const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + + SparseBlob * bottomSparseBlob = + dynamic_cast*>(bottom[0]); + // fall back to dense computation + if (bottomSparseBlob == 0) { + InnerProductLayer::Backward_gpu(top, propagate_down, bottom); + return; + } + // Gradient with respect to weight + if (this->param_propagate_down_[0]) { + const Dtype* top_diff = top[0]->gpu_diff(); + const Dtype* bottom_data = bottomSparseBlob->gpu_data(); + const int* bottom_indices = bottomSparseBlob->gpu_indices(); + const int* bottom_ptr = bottomSparseBlob->gpu_ptr(); + const int nzz = bottomSparseBlob->nzz(); + caffe_gpu_csr_gemm(CblasTrans, CblasNoTrans, this->K_, + this->N_, + this->M_, (Dtype) 1., nzz, bottom_data, + bottom_indices, bottom_ptr, top_diff, + (Dtype) 0., + this->blobs_[0]->mutable_gpu_diff(), + CblasColMajor); + } + if (this->bias_term_ && this->param_propagate_down_[1]) { + const Dtype* top_diff = top[0]->gpu_diff(); + // Gradient with respect to bias + caffe_gpu_gemv(CblasTrans, this->M_, this->N_, (Dtype) 1., + top_diff, + this->bias_multiplier_.gpu_data(), + (Dtype) 0., + this->blobs_[1]->mutable_gpu_diff()); + } + if (propagate_down[0]) { + LOG(FATAL) << "propagate down is not supported by sparse inner product"; + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(SparseInnerProductLayer); + +} // namespace caffe diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp index 21ab15fd31b..47512209983 100644 --- a/src/caffe/net.cpp +++ b/src/caffe/net.cpp @@ -357,7 +357,8 @@ void Net::AppendTop(const NetParameter& param, const int layer_id, } else { LOG(INFO) << "Input " << top_id << " -> " << blob_name; } - shared_ptr > blob_pointer(new Blob()); + 
shared_ptr > blob_pointer(layer_param ? + GetTopBlob(layer_param, top_id) : new Blob()); const int blob_id = blobs_.size(); blobs_.push_back(blob_pointer); blob_names_.push_back(blob_name); diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 8086ad66579..cf6e3b3575d 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -30,6 +30,15 @@ message Datum { optional bool encoded = 7 [default = false]; } +// A sparse vector. Indices are the positions in the vector of the non-zero entries from data +message SparseDatum { + optional int32 size = 1 [default = 0]; + optional int32 nn = 2 [default = 0]; // number of non-zero entries + repeated int32 indices = 3 [packed=true]; + optional int32 label = 4; + repeated float data = 5 [packed=true]; +} + message FillerParameter { // The filler type. optional string type = 1 [default = 'constant']; @@ -206,7 +215,7 @@ message NetStateRule { // NOTE // Update the next available ID when you add a new LayerParameter field. // -// LayerParameter next available ID: 42 (last added: exp_param) +// LayerParameter next available ID: 43 (last added: data_sparse_input_param) message LayerParameter { repeated string bottom = 2; // the name of the bottom blobs repeated string top = 3; // the name of the top blobs @@ -227,7 +236,7 @@ message LayerParameter { // line above the enum. Update the next available ID when you add a new // LayerType. // - // LayerType next available ID: 39 (last added: EXP) + // LayerType next available ID: 41 (last added: SPARSE_INNER_PRODUCT) enum LayerType { // "NONE" layer type is 0th enum element so that we don't cause confusion // by defaulting to an existent LayerType (instead, should usually error if @@ -241,6 +250,7 @@ CONTRASTIVE_LOSS = 37; CONVOLUTION = 4; DATA = 5; + DATA_SPARSE_INPUT = 39; DROPOUT = 6; DUMMY_DATA = 32; EUCLIDEAN_LOSS = 7; @@ -266,6 +276,7 @@ SILENCE = 36; SOFTMAX = 20; SOFTMAX_LOSS = 21; + SPARSE_INNER_PRODUCT = 40; SPLIT = 22; SLICE = 33; TANH = 23; @@ -305,6 +316,7 @@ optional ContrastiveLossParameter contrastive_loss_param = 40; optional ConvolutionParameter convolution_param = 10; optional DataParameter data_param = 11; + optional DataSparseInputParameter data_sparse_input_param = 42; optional DropoutParameter dropout_param = 12; optional DummyDataParameter dummy_data_param = 26; optional EltwiseParameter eltwise_param = 24; @@ -445,6 +457,19 @@ message DataParameter { optional bool mirror = 6 [default = false]; } +// Message that stores parameters used by DataLayerSparseInput +message DataSparseInputParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. Size is the number of columns in the data. + optional uint32 batch_size = 2; + // The rand_skip variable is for the data layer to skip a few data points + // so that asynchronous sgd clients do not all start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the leveldb. 
+ optional uint32 rand_skip = 3 [default = 0]; +} + // Message that stores parameters used by DropoutLayer message DropoutParameter { optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio diff --git a/src/caffe/sparse_blob.cpp b/src/caffe/sparse_blob.cpp new file mode 100644 index 00000000000..ad607d8dc13 --- /dev/null +++ b/src/caffe/sparse_blob.cpp @@ -0,0 +1,212 @@ +#include "caffe/sparse_blob.hpp" +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/syncedmem.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template +void SparseBlob::Reshape(const int num, const int channels, + const int nzz) { + CHECK_GE(num, 0); + CHECK_GE(channels, 0); + CHECK_GE(nzz, 0); + + const int previous_num = this->num_; + this->num_ = num; + this->channels_ = channels; + this->height_ = 1; + this->width_ = 1; + this->count_ = this->num_ * this->channels_; + if (this->count_) { + if (nzz != nzz_) { + nzz_ = nzz; + this->data_.reset(new SyncedMemory(nzz_ * sizeof(Dtype))); + indices_.reset(new SyncedMemory(nzz_ * sizeof(int))); + } + if (previous_num != num) { + ptr_.reset(new SyncedMemory((this->num_ + 1) * sizeof(int))); + } + } else { + this->data_.reset(reinterpret_cast(NULL)); + indices_.reset(reinterpret_cast(NULL)); + ptr_.reset(reinterpret_cast(NULL)); + } +} + +template +void SparseBlob::Reshape(const int num, const int channels, + const int height, const int width) { + CHECK_EQ(height, 1); + CHECK_EQ(width, 1); + Reshape(num, channels, 1); // 1 to make sure something is created +} + +template +void SparseBlob::ReshapeLike(const Blob& other) { + if (const SparseBlob* sparseBlob = + dynamic_cast*>((Blob*) (&other))) { + Reshape(sparseBlob->num(), sparseBlob->channels(), sparseBlob->nzz()); + } else { + Reshape(other.num(), other.channels(), other.height(), other.width()); + } +} + +template +SparseBlob::SparseBlob(const int num, const int channels, + const int nzz) { + nzz_ = 0; + this->num_ = 0; + Reshape(num, channels, nzz); +} + +template +void SparseBlob::set_cpu_data(Dtype* data) { + LOG(FATAL)<< "set_cpu_data is not supported"; +} + +template +void SparseBlob::set_gpu_data(Dtype* data) { + LOG(FATAL)<< "set_gpu_data is not supported"; +} + +template +void SparseBlob::set_cpu_data(Dtype* data, int* indices, int* ptr, + int nzz, int total_size) { + CHECK(data); + CHECK(indices); + CHECK(ptr); + nzz_ = nzz; + if (total_size == -1) { + total_size = nzz; + } + CHECK_GE(total_size, nzz); + this->data_->set_cpu_data(reinterpret_cast(data), + total_size * sizeof(Dtype)); + indices_->set_cpu_data(reinterpret_cast(indices), + total_size * sizeof(int)); + ptr_->set_cpu_data(reinterpret_cast(ptr), + (this->num_ + 1) * sizeof(int)); +} +template +void SparseBlob::set_gpu_data(Dtype* data, int* indices, int* ptr, + int nzz, int total_size) { + CHECK(data); + CHECK(indices); + CHECK(ptr); + nzz_ = nzz; + if (total_size == -1) { + total_size = nzz; + } + CHECK_GE(total_size, nzz); + this->data_->set_gpu_data(data, total_size * sizeof(Dtype)); + indices_->set_gpu_data(indices, total_size * sizeof(int)); + ptr_->set_gpu_data(ptr, (this->num_ + 1) * sizeof(int)); +} + +template +const Dtype* SparseBlob::cpu_diff() const { + LOG(FATAL)<< "cpu_diff is not supported"; + return NULL; +} + +template +const Dtype* SparseBlob::gpu_diff() const { + LOG(FATAL)<< "gpu_diff is not supported"; + return NULL; +} + +template +Dtype* SparseBlob::mutable_cpu_diff() { + LOG(FATAL)<< "cpu_mutable_diff is not supported"; + return NULL; +} + +template +Dtype* 
SparseBlob::mutable_gpu_diff() { + LOG(FATAL)<< "gpu_mutable_diff is not supported"; + return NULL; +} + +template +const int* SparseBlob::cpu_indices() const { + CHECK(indices_); + return (const int*) indices_->cpu_data(); +} + +template +const int* SparseBlob::cpu_ptr() const { + CHECK(ptr_); + return (const int*) ptr_->cpu_data(); +} + +template +const int* SparseBlob::gpu_indices() const { + CHECK(indices_); + return (const int*) indices_->gpu_data(); +} + +template +const int* SparseBlob::gpu_ptr() const { + CHECK(ptr_); + return (const int*) ptr_->gpu_data(); +} + +template +int* SparseBlob::mutable_cpu_indices() { + CHECK(indices_); + return reinterpret_cast(indices_->mutable_cpu_data()); +} + +template +int* SparseBlob::mutable_cpu_ptr() { + CHECK(ptr_); + return reinterpret_cast(ptr_->mutable_cpu_data()); +} + +template +int* SparseBlob::mutable_gpu_indices() { + CHECK(indices_); + return reinterpret_cast(indices_->mutable_gpu_data()); +} + +template +int* SparseBlob::mutable_gpu_ptr() { + CHECK(ptr_); + return reinterpret_cast(ptr_->mutable_gpu_data()); +} + +template +void SparseBlob::ShareData(const Blob& other) { + LOG(FATAL)<< "ShareData is not supported"; +} + +template +void SparseBlob::ShareDiff(const Blob& other) { + LOG(FATAL)<< "ShareDiff is not supported"; +} + +template +void SparseBlob::Update() { + LOG(FATAL)<< "Update is not supported"; +} + +template +void SparseBlob::CopyFrom(const Blob& source, bool copy_diff, + bool reshape) { + LOG(FATAL)<< "CopyFrom is not supported"; +} + +template +void SparseBlob::FromProto(const BlobProto& proto) { + LOG(FATAL)<< "FromProto is not supported"; +} + +template +void SparseBlob::ToProto(BlobProto* proto, bool write_diff) const { + LOG(FATAL)<< "ToProto is not supported"; +} +INSTANTIATE_CLASS(SparseBlob); +} // namespace caffe + diff --git a/src/caffe/syncedmem.cpp b/src/caffe/syncedmem.cpp index 7617ccfb27f..2b24e15c709 100644 --- a/src/caffe/syncedmem.cpp +++ b/src/caffe/syncedmem.cpp @@ -7,15 +7,7 @@ namespace caffe { SyncedMemory::~SyncedMemory() { - if (cpu_ptr_ && own_cpu_data_) { - CaffeFreeHost(cpu_ptr_); - } - -#ifndef CPU_ONLY - if (gpu_ptr_) { - CUDA_CHECK(cudaFree(gpu_ptr_)); - } -#endif // CPU_ONLY + clear_data(); } inline void SyncedMemory::to_cpu() { @@ -49,12 +41,14 @@ inline void SyncedMemory::to_gpu() { switch (head_) { case UNINITIALIZED: CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); + own_gpu_data_ = true; caffe_gpu_memset(size_, 0, gpu_ptr_); head_ = HEAD_AT_GPU; break; case HEAD_AT_CPU: if (gpu_ptr_ == NULL) { CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); + own_gpu_data_ = true; } caffe_gpu_memcpy(size_, cpu_ptr_, gpu_ptr_); head_ = SYNCED; @@ -68,21 +62,58 @@ inline void SyncedMemory::to_gpu() { #endif } +void SyncedMemory::clear_data() { + if (cpu_ptr_ && own_cpu_data_) { + CaffeFreeHost(cpu_ptr_); + cpu_ptr_ = NULL; + } +#ifndef CPU_ONLY + if (gpu_ptr_ && own_gpu_data_) { + CUDA_CHECK(cudaFree(gpu_ptr_)); + gpu_ptr_ = NULL; + } +#endif // CPU_ONLY + head_ = UNINITIALIZED; +} + const void* SyncedMemory::cpu_data() { to_cpu(); return (const void*)cpu_ptr_; } -void SyncedMemory::set_cpu_data(void* data) { +void SyncedMemory::set_cpu_data(void* data, int size) { CHECK(data); - if (own_cpu_data_) { - CaffeFreeHost(cpu_ptr_); + if (size != -1 && size_ != size) { + clear_data(); + size_ = size; + } + if (cpu_ptr_ && own_cpu_data_) { + CaffeFreeHost(cpu_ptr_); } cpu_ptr_ = data; head_ = HEAD_AT_CPU; own_cpu_data_ = false; } +void SyncedMemory::set_gpu_data(void* data, int size) { +#ifndef CPU_ONLY + CHECK(data); + if 
(size != -1 && size_ != size) { + clear_data(); + size_ = size; + } + if (gpu_ptr_ && own_gpu_data_) { + CUDA_CHECK(cudaFree(gpu_ptr_)); + } + + gpu_ptr_ = data; + head_ = HEAD_AT_GPU; + own_gpu_data_ = false; +#else + NO_GPU; +#endif +} + const void* SyncedMemory::gpu_data() { #ifndef CPU_ONLY to_gpu(); diff --git a/src/caffe/test/test_data_layer_sparse.cpp b/src/caffe/test/test_data_layer_sparse.cpp new file mode 100644 index 00000000000..c7ff933218d --- /dev/null +++ b/src/caffe/test/test_data_layer_sparse.cpp @@ -0,0 +1,177 @@ +#include +#include + +#include "leveldb/db.h" + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/sparse_blob.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +using std::string; +using std::stringstream; + +namespace caffe { + +template +class DataLayerSparseTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + DataLayerSparseTest() + : filename_(new string(tmpnam(NULL))), + blob_top_data_(new SparseBlob()), + blob_top_label_(new Blob()) { + } + virtual void SetUp() { + blob_top_vec_.push_back(blob_top_data_); + blob_top_vec_.push_back(blob_top_label_); + } + + void FillLevelDB() { + LOG(INFO)<< "Using temporary leveldb " << *filename_; + leveldb::DB* db; + leveldb::Options options; + options.error_if_exists = true; + options.create_if_missing = true; + leveldb::Status status = + leveldb::DB::Open(options, filename_->c_str(), &db); + CHECK(status.ok()); + for (int i = 0; i < 6; ++i) { + SparseDatum datum; + datum.set_label(i); + datum.set_size(6); + datum.set_nn(i+1); + for (int j = 0; j < i+1; ++j) { + datum.mutable_data()->Add(j+1); + datum.mutable_indices()->Add(j); + } + stringstream ss; + ss << i; + db->Put(leveldb::WriteOptions(), ss.str(), datum.SerializeAsString()); + } + delete db; + } + + void TestRead() { + LayerParameter param; + DataSparseInputParameter* data_param = + param.mutable_data_sparse_input_param(); + data_param->set_batch_size(6); + data_param->set_source(filename_->c_str()); + DataLayerSparseInput layer(param); + layer.SetUp(blob_bottom_vec_, blob_top_vec_); + EXPECT_EQ(blob_top_data_->num(), 6); + EXPECT_EQ(blob_top_data_->channels(), 6); + EXPECT_EQ(blob_top_data_->height(), 1); + EXPECT_EQ(blob_top_data_->width(), 1); + EXPECT_EQ(blob_top_label_->num(), 6); + EXPECT_EQ(blob_top_label_->channels(), 1); + EXPECT_EQ(blob_top_label_->height(), 1); + EXPECT_EQ(blob_top_label_->width(), 1); + + for (int iter = 0; iter < 100; ++iter) { + layer.Forward(blob_bottom_vec_, blob_top_vec_); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(i, blob_top_label_->cpu_data()[i]); + } + EXPECT_EQ(0, blob_top_data_->cpu_ptr()[0]); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ((i+1) * (i+2)/2, + blob_top_data_->cpu_ptr()[i+1]) << + "debug ptr: iter " << iter << " i " << i; + for (int j = 0; j < i; ++j) { + EXPECT_EQ(j+1, blob_top_data_-> + cpu_data()[blob_top_data_->cpu_ptr()[i]+j]) + << "debug data: iter " << iter << " i " << i << " j " << j; + EXPECT_EQ(j, blob_top_data_-> + cpu_indices()[blob_top_data_->cpu_ptr()[i]+j]) + << "debug indices: iter " << iter << " i " << i << " j " << j; + } + } + } + } + void TestRead2() { + LayerParameter param; + DataSparseInputParameter* data_param = + param.mutable_data_sparse_input_param(); + // half the previous batch size to alternate between 2 different dataset + data_param->set_batch_size(3); + 
data_param->set_source(filename_->c_str()); + DataLayerSparseInput layer(param); + layer.SetUp(blob_bottom_vec_, blob_top_vec_); + EXPECT_EQ(blob_top_data_->num(), 3); + EXPECT_EQ(blob_top_data_->channels(), 6); + EXPECT_EQ(blob_top_data_->height(), 1); + EXPECT_EQ(blob_top_data_->width(), 1); + EXPECT_EQ(blob_top_label_->num(), 3); + EXPECT_EQ(blob_top_label_->channels(), 1); + EXPECT_EQ(blob_top_label_->height(), 1); + EXPECT_EQ(blob_top_label_->width(), 1); + + int delta = 0; + for (int iter = 0; iter < 100; ++iter) { + layer.Forward(blob_bottom_vec_, blob_top_vec_); + if (iter % 2) { + delta = 3; + } else { + delta= 0; + } + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(i+delta, blob_top_label_->cpu_data()[i]); + } + + EXPECT_EQ(0, blob_top_data_->cpu_ptr()[0]); + if (delta == 0) { + EXPECT_EQ(1, blob_top_data_->cpu_ptr()[1]); + EXPECT_EQ(3, blob_top_data_->cpu_ptr()[2]); + EXPECT_EQ(6, blob_top_data_->cpu_ptr()[3]); + } else { + EXPECT_EQ(4, blob_top_data_->cpu_ptr()[1]); + EXPECT_EQ(9, blob_top_data_->cpu_ptr()[2]); + EXPECT_EQ(15, blob_top_data_->cpu_ptr()[3]); + } + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < i+delta; ++j) { + EXPECT_EQ(j+1, + blob_top_data_->cpu_data()[blob_top_data_->cpu_ptr()[i]+j]) + << "debug data: iter " << iter << " i " << i << " j " << j; + EXPECT_EQ(j, + blob_top_data_->cpu_indices()[blob_top_data_->cpu_ptr()[i]+j]) + << "debug indices: iter " << iter << " i " << i << " j " << j; + } + } + } + } + + virtual ~DataLayerSparseTest() { + delete blob_top_data_; + delete blob_top_label_; + } + + shared_ptr filename_; + SparseBlob* const blob_top_data_; + Blob* const blob_top_label_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(DataLayerSparseTest, TestDtypesAndDevices); + +TYPED_TEST(DataLayerSparseTest, TestReadLevelDB) { + this->FillLevelDB(); + this->TestRead(); +} + +TYPED_TEST(DataLayerSparseTest, TestReadLevelDB2) { + this->FillLevelDB(); + this->TestRead2(); +} + +} // namespace caffe diff --git a/src/caffe/test/test_math_functions.cpp b/src/caffe/test/test_math_functions.cpp index 667f744bdd7..4fbb558aeeb 100644 --- a/src/caffe/test/test_math_functions.cpp +++ b/src/caffe/test/test_math_functions.cpp @@ -3,12 +3,13 @@ #include #include // for std::fabs #include // for rand_r - +#include #include "gtest/gtest.h" #include "caffe/blob.hpp" #include "caffe/common.hpp" #include "caffe/filler.hpp" +#include "caffe/util/benchmark.hpp" #include "caffe/util/math_functions.hpp" #include "caffe/test/test_caffe_main.hpp" @@ -228,8 +229,690 @@ TYPED_TEST(MathFunctionsTest, TestCopyGPU) { EXPECT_EQ(bottom_data[i], top_data[i]); } } +#endif + +template +class CsrFunctionsGenTest : public ::testing::Test { + protected: + CsrFunctionsGenTest() + : A_(), + indices_(), + ptr_(), + B_(), + C_(), + M(0), + N(0), + K(0), + NZZ(0), + PTR_SIZE(0), + TransA(CblasNoTrans), + TransB(CblasNoTrans), + alpha(1.0), + beta(0.0), + orderC(CblasRowMajor) { + } + + virtual void SetUp(int m, int n, int k, int nzz, int ptr_size) { + M = m; + N = n; + K = k; + NZZ = nzz; + PTR_SIZE = ptr_size; + + A_.reset(new SyncedMemory(nzz * sizeof(Dtype))); + indices_.reset(new SyncedMemory(nzz * sizeof(int))); + ptr_.reset(new SyncedMemory(ptr_size * sizeof(int))); + B_.reset(new SyncedMemory(K * N * sizeof(Dtype))); + C_.reset(new SyncedMemory(M * N * sizeof(Dtype))); + } + + virtual void run(bool isCpu, int times = 1) { + if (isCpu) { + Timer timer; + timer.Start(); + for (int t = 0; t < times; t++) { + caffe_cpu_csr_gemm(TransA, TransB, M, N, K, alpha, NZZ, 
cpu_A(), + cpu_indices(), cpu_ptr(), cpu_B(), beta, cpu_C(), + orderC); + } + std::cout << "Total Time for CSR CPU gemm M:" << M << " N: " << N + << " K: " << K << " transA: " << TransA << " transB: " << TransB + << " orderC: " << orderC << " equal to " + << (timer.MilliSeconds() / times) + << " milli seconds.. Time per M ops: " + << timer.MilliSeconds() / (times * NZZ * N / 1e6) + << " milli seconds\n"; + } else { +#ifndef CPU_ONLY + Dtype* agpu = gpu_A(); + int* indicesgpu = gpu_indices(); + int* ptrgpu = gpu_ptr(); + Dtype* bgpu = gpu_B(); + Dtype* cgpu = gpu_C(); + Timer timer; + timer.Start(); + for (int t = 0; t < times; t++) { + caffe_gpu_csr_gemm(TransA, TransB, M, N, K, alpha, NZZ, agpu, + indicesgpu, ptrgpu, bgpu, beta, cgpu, orderC); + } + cudaDeviceSynchronize(); + std::cout << "Total Time for CSR GPU gemm M:" << M << " N: " << N + << " K: " << K << " transA: " << TransA << " transB: " << TransB + << " orderC: " << orderC << " equal to " + << (timer.MilliSeconds() / times) + << " milli seconds. Time per M ops: " + << timer.MilliSeconds() / (times * NZZ * N / 1e6) + << " milli seconds\n"; +#else + +#endif + } + } + + void setA(Dtype A_data[], int A_indices[], int A_ptr[]) { + Dtype* am = cpu_A(); + int* aindices = cpu_indices(); + int* aptr = cpu_ptr(); + + for (int i = 0; i < NZZ; i++) { + am[i] = A_data[i]; + aindices[i] = A_indices[i]; + } + for (int i = 0; i < PTR_SIZE; i++) { + aptr[i] = A_ptr[i]; + } + } + + void setB(Dtype B_data[]) { + Dtype* bm = cpu_B(); + for (int i = 0; i < (K * N); i++) { + bm[i] = B_data[i]; + } + } + void setC(Dtype C_data[]) { + Dtype* cm = cpu_C(); + for (int i = 0; i < (M * N); i++) { + cm[i] = C_data[i]; + } + } + void checkC(Dtype C_check[]) { + Dtype* cm = cpu_C(); + for (int i = 0; i < (M * N); i++) { + EXPECT_EQ(cm[i], C_check[i]); + } + } + + Dtype* cpu_A() { + CHECK(A_); + return reinterpret_cast(A_->mutable_cpu_data()); + } + Dtype* gpu_A() { + CHECK(A_); + return reinterpret_cast(A_->mutable_gpu_data()); + } + + Dtype* cpu_B() { + CHECK(B_); + return reinterpret_cast(B_->mutable_cpu_data()); + } + Dtype* gpu_B() { + CHECK(B_); + return reinterpret_cast(B_->mutable_gpu_data()); + } + + Dtype* cpu_C() { + CHECK(C_); + return reinterpret_cast(C_->mutable_cpu_data()); + } + Dtype* gpu_C() { + CHECK(C_); + return reinterpret_cast(C_->mutable_gpu_data()); + } + + int* cpu_indices() { + CHECK(indices_); + return reinterpret_cast(indices_->mutable_cpu_data()); + } + int* gpu_indices() { + CHECK(indices_); + return reinterpret_cast(indices_->mutable_gpu_data()); + } + + int* cpu_ptr() { + CHECK(ptr_); + return reinterpret_cast(ptr_->mutable_cpu_data()); + } + int* gpu_ptr() { + CHECK(ptr_); + return reinterpret_cast(ptr_->mutable_gpu_data()); + } + + void random_csr(int M, int N, int nzz_per_row, Dtype* A, int* indices, + int* ptr) { + srand(0); + ptr[0] = 0; + for (int row = 0; row < M; row++) { + ptr[row+1] = nzz_per_row * (row+1); + for (int pos = 0; pos < nzz_per_row; pos++) { + int col = caffe_rng_rand() % N; + indices[row * nzz_per_row + pos] = col; + A[row * nzz_per_row + pos] = + static_cast (caffe_rng_rand()) / + static_cast (RAND_MAX); + } + } + } + + void random_fill(int size, Dtype* X) { + srand(0); + for (int pos = 0; pos < size; pos++) { + X[pos] = static_cast(caffe_rng_rand()) / + static_cast(RAND_MAX); + } + } + + void test_speed_forward(int batch_size, int features, int nzz_per_row, + int classes) { + Dtype* A = new Dtype[batch_size * nzz_per_row]; + int* indices = new int[batch_size * nzz_per_row]; + int* ptr = new 
int[batch_size + 1]; + Dtype* B = new Dtype[features * classes]; + Dtype* C = new Dtype[batch_size * classes]; + this->random_csr(batch_size, features, nzz_per_row, A, indices, ptr); + this->random_fill(features * classes, B); + this->random_fill(batch_size * classes, C); + + this->alpha = 1.0; + this->beta = 1.0; + this->SetUp(batch_size, classes, features, batch_size * nzz_per_row, + batch_size + 1); + this->TransA = CblasNoTrans; + this->TransB = CblasTrans; + this->orderC = CblasRowMajor; + + this->setA(A, indices, ptr); + this->setB(B); + this->setC(C); + this->run(true, 100); + + this->setC(C); +#ifndef CPU_ONLY + this->run(false, 100); +#else +#endif + delete[] A; + delete[] indices; + delete[] ptr; + delete[] B; + delete[] C; + } + + void test_speed_backward(int batch_size, int features, int nzz_per_row, + int classes) { + Dtype* A = new Dtype[batch_size * nzz_per_row]; + int* indices = new int[batch_size * nzz_per_row]; + int* ptr = new int[batch_size + 1]; + Dtype* B = new Dtype[batch_size * classes]; + Dtype* C = new Dtype[features * classes]; + this->random_csr(batch_size, features, nzz_per_row, A, indices, ptr); + this->random_fill(batch_size * classes, B); + this->random_fill(features * classes, C); + + this->alpha = 1.0; + this->beta = 1.0; + this->SetUp(features, classes, batch_size, batch_size * nzz_per_row, + batch_size + 1); + this->TransA = CblasTrans; + this->TransB = CblasNoTrans; + this->orderC = CblasColMajor; + + this->setA(A, indices, ptr); + this->setB(B); + this->setC(C); + this->run(true, 100); + + this->setC(C); +#ifndef CPU_ONLY + this->run(false, 100); +#else +#endif + + delete[] A; + delete[] indices; + delete[] ptr; + delete[] B; + delete[] C; + } + + shared_ptr<SyncedMemory> A_; + shared_ptr<SyncedMemory> indices_; + shared_ptr<SyncedMemory> ptr_; + shared_ptr<SyncedMemory> B_; + shared_ptr<SyncedMemory> C_; + int M; + int N; + int K; + int NZZ; + int PTR_SIZE; + + CBLAS_TRANSPOSE TransA; + CBLAS_TRANSPOSE TransB; + Dtype alpha; + Dtype beta; + CBLAS_ORDER orderC; +}; + +typedef ::testing::Types<float, double> Dtypes; +TYPED_TEST_CASE(CsrFunctionsGenTest, Dtypes); + +TYPED_TEST(CsrFunctionsGenTest, TestCsrGemm1) { +TypeParam A[] = {1.0, 2.0, 3.0}; +int indices[] = {0, 2, 1}; +int ptr[] = {0, 2, 3}; +TypeParam B[] = {4.0, 7.0, 5.0, 8.0, 6.0, 9.0}; +TypeParam C[] = {0.0, 0.0, 0.0, 0.0}; +TypeParam CCheck[] = {16.0, 25.0, 15.0, 24.0}; +this->alpha = 1.0; +this->beta = 1.0; +this->SetUp(2, 2, 3, 3, 3); + +this->setA(A, indices, ptr); +this->setB(B); +this->setC(C); +this->run(true); +this->checkC(CCheck); +this->setC(C); + +#ifndef CPU_ONLY +this->run(false); +this->checkC(CCheck); +#else + +#endif +} + +TYPED_TEST(CsrFunctionsGenTest, TestCsrGemm2) { +TypeParam A[] = {1.0, 2.0, 3.0}; +int indices[] = {0, 2, 1}; +int ptr[] = {0, 2, 3}; +TypeParam B[] = {4.0, 7.0, 5.0, 8.0, 6.0, 9.0}; +TypeParam C[] = {1.0, 2.0, 3.0, 4.0}; +TypeParam CCheck[] = {17.0, 27.0, 18.0, 28.0}; +this->alpha = 1.0; +this->beta = 1.0; +this->SetUp(2, 2, 3, 3, 3); +this->TransA = CblasNoTrans; +this->TransB = CblasNoTrans; +this->orderC = CblasRowMajor; + +this->setA(A, indices, ptr); +this->setB(B); +this->setC(C); +this->run(true); +this->checkC(CCheck); +this->setC(C); +#ifndef CPU_ONLY +this->run(false); +this->checkC(CCheck); +#else +#endif +} + +TYPED_TEST(CsrFunctionsGenTest, TestCsrGemm3) { +TypeParam A[] = {1.0, 2.0, 3.0}; +int indices[] = {0, 2, 1}; +int ptr[] = {0, 2, 3}; +TypeParam B[] = {4.0, 7.0, 5.0, 8.0, 6.0, 9.0}; +TypeParam C[] = {1.0, 2.3, 3.0, 4.0}; +TypeParam CCheck[] = {16.0, 25.0, 15.0, 24.0}; +this->alpha = 1.0; +this->beta = 0.0; +this->SetUp(2, 2, 3, 3, 3); + 
+this->setA(A, indices, ptr); +this->setB(B); +this->setC(C); +this->run(true); +this->checkC(CCheck); +this->setC(C); +#ifndef CPU_ONLY +this->run(false); +this->checkC(CCheck); +#else + +#endif +} + +TYPED_TEST(CsrFunctionsGenTest, TestCsrGemm4) { +TypeParam A[] = {1.0, 2.0, 3.0}; +int indices[] = {0, 2, 1}; +int ptr[] = {0, 2, 3}; +TypeParam B[] = {4.0, 7.0, 5.0, 8.0, 6.0, 9.0}; +TypeParam C[] = {0.0, 0.0, 0.0, 0.0}; +TypeParam CCheck[] = {16.0, 15.0, 25.0, 24.0}; +this->alpha = 1.0; +this->beta = 1.0; +this->SetUp(2, 2, 3, 3, 3); +this->TransA = CblasNoTrans; +this->TransB = CblasNoTrans; +this->orderC = CblasColMajor; + +this->setA(A, indices, ptr); +this->setB(B); +this->setC(C); +this->run(true); +this->checkC(CCheck); +this->setC(C); +#ifndef CPU_ONLY +this->run(false); +this->checkC(CCheck); +#else + +#endif +} + +TYPED_TEST(CsrFunctionsGenTest, TestCsrGemm5) { +TypeParam A[] = {1.0, 2.0, 3.0}; +int indices[] = {0, 2, 1}; +int ptr[] = {0, 2, 3}; +TypeParam B[] = {4.0, 7.0, 5.0, 8.0, 6.0, 9.0}; +TypeParam C[] = {1.0, 2.0, 0.0, 0.0}; +TypeParam CCheck[] = {17.0, 17.0, 25.0, 24.0}; +this->alpha = 1.0; +this->beta = 1.0; +this->SetUp(2, 2, 3, 3, 3); +this->TransA = CblasNoTrans; +this->TransB = CblasNoTrans; +this->orderC = CblasColMajor; + +this->setA(A, indices, ptr); +this->setB(B); +this->setC(C); +this->run(true); +this->checkC(CCheck); +this->setC(C); +#ifndef CPU_ONLY +this->run(false); +this->checkC(CCheck); +#else + +#endif +} + +TYPED_TEST(CsrFunctionsGenTest, TestCsrGemm6) { +TypeParam A[] = {1.0, 2.0, 3.0}; +int indices[] = {0, 2, 1}; +int ptr[] = {0, 2, 3}; +TypeParam B[] = {4.0, 7.0, 5.0, 8.0, 6.0, 9.0}; +TypeParam C[] = {1.0, 2.0, 3.0, 0.0}; +TypeParam CCheck[] = {16.0, 15.0, 25.0, 24.0}; +this->alpha = 1.0; +this->beta = 0.0; +this->SetUp(2, 2, 3, 3, 3); +this->TransA = CblasNoTrans; +this->TransB = CblasNoTrans; +this->orderC = CblasColMajor; + +this->setA(A, indices, ptr); +this->setB(B); +this->setC(C); +this->run(true); +this->checkC(CCheck); +this->setC(C); +#ifndef CPU_ONLY +this->run(false); +this->checkC(CCheck); +#else + +#endif +} + +TYPED_TEST(CsrFunctionsGenTest, TestCsrGemm7) { +TypeParam A[] = {1.0, 2.0, 3.0}; +int indices[] = {0, 2, 1}; +int ptr[] = {0, 2, 3}; +TypeParam B[] = {4.0, 7.0, 5.0, 8.0, 6.0, 9.0}; +TypeParam C[] = {0.0, 0.0, 0.0, 0.0}; +TypeParam CCheck[] = {32.0, 50.0, 30.0, 48.0}; +this->alpha = 2.0; +this->beta = 1.0; +this->SetUp(2, 2, 3, 3, 3); + +this->setA(A, indices, ptr); +this->setB(B); +this->setC(C); +this->run(true); +this->checkC(CCheck); +this->setC(C); +#ifndef CPU_ONLY +this->run(false); +this->checkC(CCheck); +#else + +#endif +} + +TYPED_TEST(CsrFunctionsGenTest, TestCsrGemm8) { +TypeParam A[] = {1.0, 2.0, 3.0}; +int indices[] = {0, 2, 1}; +int ptr[] = {0, 2, 3}; +TypeParam B[] = {4.0, 7.0, 5.0, 8.0, 6.0, 9.0}; +TypeParam C[] = {1.0, 2.0, 3.0, 0.0}; +TypeParam CCheck[] = {31.0, 58.0, 51.0, 36.0}; +this->alpha = 2.0; +this->beta = 3.0; +this->SetUp(2, 2, 3, 3, 3); +this->TransA = CblasNoTrans; +this->TransB = CblasTrans; +this->orderC = CblasRowMajor; + +this->setA(A, indices, ptr); +this->setB(B); +this->setC(C); +this->run(true); +this->checkC(CCheck); +this->setC(C); +#ifndef CPU_ONLY +this->run(false); +this->checkC(CCheck); +#else + +#endif +} + +TYPED_TEST(CsrFunctionsGenTest, TestCsrGemm9) { +TypeParam A[] = {1.0, 2.0, 3.0}; +int indices[] = {0, 2, 1}; +int ptr[] = {0, 2, 3}; +TypeParam B[] = {4.0, 7.0, 5.0, 8.0, 6.0, 9.0}; +TypeParam C[] = {1.0, 2.0, 3.0, 0.0}; +TypeParam CCheck[] = {31.0, 48.0, 61.0, 36.0}; 
+this->alpha = 2.0; +this->beta = 3.0; +this->SetUp(2, 2, 3, 3, 3); +this->TransA = CblasNoTrans; +this->TransB = CblasTrans; +this->orderC = CblasColMajor; + +this->setA(A, indices, ptr); +this->setB(B); +this->setC(C); +this->run(true); +this->checkC(CCheck); +this->setC(C); +#ifndef CPU_ONLY +this->run(false); +this->checkC(CCheck); +#else + +#endif +} + +TYPED_TEST(CsrFunctionsGenTest, TestCsrGemm10) { +TypeParam A[] = {1.0, 2.0, 3.0}; +int indices[] = {0, 2, 1}; +int ptr[] = {0, 2, 3}; +TypeParam B[] = {4.0, 7.0, 5.0, 8.0, 6.0, 9.0}; +TypeParam C[] = {1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; +TypeParam CCheck[] = {11.0, 20.0, 19.0, 48.0, 36.0, 54.0, 16.0, 28.0, 20.0}; +this->alpha = 2.0; +this->beta = 3.0; +this->SetUp(3, 3, 2, 3, 3); +this->TransA = CblasTrans; +this->TransB = CblasNoTrans; +this->orderC = CblasRowMajor; + +this->setA(A, indices, ptr); +this->setB(B); +this->setC(C); +this->run(true); +this->checkC(CCheck); +this->setC(C); +#ifndef CPU_ONLY +this->run(false); +this->checkC(CCheck); +#else #endif +} + +TYPED_TEST(CsrFunctionsGenTest, TestCsrGemm11) { +TypeParam A[] = {1.0, 2.0, 3.0}; +int indices[] = {0, 2, 1}; +int ptr[] = {0, 2, 3}; +TypeParam B[] = {4.0, 7.0, 5.0, 8.0, 6.0, 9.0}; +TypeParam C[] = {1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; +TypeParam CCheck[] = {11.0, 54.0, 25.0, 14.0, 36.0, 28.0, 10.0, 54.0, 20.0}; +this->alpha = 2.0; +this->beta = 3.0; +this->SetUp(3, 3, 2, 3, 3); +this->TransA = CblasTrans; +this->TransB = CblasNoTrans; +this->orderC = CblasColMajor; + +this->setA(A, indices, ptr); +this->setB(B); +this->setC(C); +this->run(true); +this->checkC(CCheck); +this->setC(C); +#ifndef CPU_ONLY +this->run(false); +this->checkC(CCheck); +#else + +#endif +} + +TYPED_TEST(CsrFunctionsGenTest, TestCsrGemm12) { +TypeParam A[] = {1.0, 2.0, 3.0}; +int indices[] = {0, 2, 1}; +int ptr[] = {0, 2, 3}; +TypeParam B[] = {4.0, 7.0, 5.0, 8.0, 6.0, 9.0}; +TypeParam C[] = {1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; +TypeParam CCheck[] = {11.0, 16.0, 21.0, 42.0, 48.0, 54.0, 16.0, 20.0, 24.0}; +this->alpha = 2.0; +this->beta = 3.0; +this->SetUp(3, 3, 2, 3, 3); +this->TransA = CblasTrans; +this->TransB = CblasTrans; +this->orderC = CblasRowMajor; + +this->setA(A, indices, ptr); +this->setB(B); +this->setC(C); +this->run(true); +this->checkC(CCheck); +this->setC(C); +#ifndef CPU_ONLY +this->run(false); +this->checkC(CCheck); +#else + +#endif +} + +TYPED_TEST(CsrFunctionsGenTest, TestCsrGemm13) { +TypeParam A[] = {1.0, 2.0, 3.0}; +int indices[] = {0, 2, 1}; +int ptr[] = {0, 2, 3}; +TypeParam B[] = {4.0, 7.0, 5.0, 8.0, 6.0, 9.0}; +TypeParam C[] = {1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; +TypeParam CCheck[] = {11.0, 48.0, 25.0, 10.0, 48.0, 20.0, 12.0, 54.0, 24.0}; +this->alpha = 2.0; +this->beta = 3.0; +this->SetUp(3, 3, 2, 3, 3); +this->TransA = CblasTrans; +this->TransB = CblasTrans; +this->orderC = CblasColMajor; + +this->setA(A, indices, ptr); +this->setB(B); +this->setC(C); +this->run(true); +this->checkC(CCheck); +this->setC(C); +#ifndef CPU_ONLY +this->run(false); +this->checkC(CCheck); +#else + +#endif +} + +TYPED_TEST(CsrFunctionsGenTest, TestCsrGemmSpeedForward) { +std::vector<int> batch_size; +std::vector<int> features; +std::vector<int> nzz_per_row; +std::vector<int> classes; + +batch_size.push_back(64); +batch_size.push_back(128); +features.push_back(10000); +nzz_per_row.push_back(200); +classes.push_back(2); +classes.push_back(10); +classes.push_back(100); + +for (int ba = 0; ba < batch_size.size(); ba++) { + for (int f = 0; f < features.size(); f++) { + for (int nr
= 0; nr < nzz_per_row.size(); nr++) { + for (int c = 0; c < classes.size(); c++) { + this->test_speed_forward(batch_size[ba], + features[f], nzz_per_row[nr], classes[c]); + } + } + } +} +} + +TYPED_TEST(CsrFunctionsGenTest, TestCsrGemmSpeedBackward) { +std::vector<int> batch_size; +std::vector<int> features; +std::vector<int> nzz_per_row; +std::vector<int> classes; + +batch_size.push_back(64); +batch_size.push_back(128); +features.push_back(10000); +nzz_per_row.push_back(200); +classes.push_back(2); +classes.push_back(10); +classes.push_back(100); + +for (int ba = 0; ba < batch_size.size(); ba++) { + for (int f = 0; f < features.size(); f++) { + for (int nr = 0; nr < nzz_per_row.size(); nr++) { + for (int c = 0; c < classes.size(); c++) { + this->test_speed_backward(batch_size[ba], + features[f], nzz_per_row[nr], classes[c]); + } + } + } +} +} + + } // namespace caffe diff --git a/src/caffe/test/test_sparse_inner_product_layer.cpp b/src/caffe/test/test_sparse_inner_product_layer.cpp new file mode 100644 index 00000000000..a856706bde3 --- /dev/null +++ b/src/caffe/test/test_sparse_inner_product_layer.cpp @@ -0,0 +1,133 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +#ifndef CPU_ONLY +extern cudaDeviceProp CAFFE_TEST_CUDA_PROP; +#endif + +template <typename TypeParam> +class SparseInnerProductLayerTest : public MultiDeviceTest<TypeParam> { + typedef typename TypeParam::Dtype Dtype; + + protected: + SparseInnerProductLayerTest() + : blob_bottom_(new SparseBlob<Dtype>(2, 3, 5)), + blob_top_(new Blob<Dtype>()) { + // fill the values + Dtype* data = blob_bottom_->mutable_cpu_data(); + for (int i = 0; i < 4; i++) { + data[i] = (Dtype)1.; + } + data[4] = (Dtype)0.; + + int* indices = blob_bottom_->mutable_cpu_indices(); + for (int i = 0; i < 5; i++) { + indices[i] = i % 3; + } + int* ptr = blob_bottom_->mutable_cpu_ptr(); + ptr[0] = 0; + ptr[1] = 2; + ptr[2] = 5; + + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~SparseInnerProductLayerTest() { + LOG(INFO) << "deleting sparse inner product layer test"; + delete blob_bottom_; + delete blob_top_; + } + SparseBlob<Dtype>* const blob_bottom_; + Blob<Dtype>* const blob_top_; + vector<Blob<Dtype>*> blob_bottom_vec_; + vector<Blob<Dtype>*> blob_top_vec_; +}; + +TYPED_TEST_CASE(SparseInnerProductLayerTest, TestDtypesAndDevices); + +TYPED_TEST(SparseInnerProductLayerTest, TestSetUp) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + InnerProductParameter* inner_product_param = + layer_param.mutable_inner_product_param(); + inner_product_param->set_num_output(10); + shared_ptr<SparseInnerProductLayer<Dtype> > layer( + new SparseInnerProductLayer<Dtype>(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), 2); + EXPECT_EQ(this->blob_top_->height(), 1); + EXPECT_EQ(this->blob_top_->width(), 1); + EXPECT_EQ(this->blob_top_->channels(), 10); +} + +TYPED_TEST(SparseInnerProductLayerTest, TestForward) { + typedef typename TypeParam::Dtype Dtype; + bool IS_VALID_CUDA = false; +#ifndef CPU_ONLY + IS_VALID_CUDA = CAFFE_TEST_CUDA_PROP.major >= 2; +#endif + if (Caffe::mode() == Caffe::CPU || + sizeof(Dtype) == 4 || IS_VALID_CUDA) { + LayerParameter layer_param; + InnerProductParameter* inner_product_param = + layer_param.mutable_inner_product_param(); + inner_product_param->set_num_output(10); + 
inner_product_param->mutable_weight_filler()->set_type("uniform"); + inner_product_param->mutable_weight_filler()->set_min(1); + inner_product_param->mutable_weight_filler()->set_max(2); + inner_product_param->mutable_bias_filler()->set_type("uniform"); + inner_product_param->mutable_bias_filler()->set_min(0); + inner_product_param->mutable_bias_filler()->set_max(0); + shared_ptr<SparseInnerProductLayer<Dtype> > layer( + new SparseInnerProductLayer<Dtype>(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* data = this->blob_top_->cpu_data(); + const int count = this->blob_top_->count(); + for (int i = 0; i < count; ++i) { + EXPECT_GE(data[i], 2.); + EXPECT_LE(data[i], 4.); + } + } else { + LOG(ERROR) << "Skipping test due to old architecture."; + } +} + +TYPED_TEST(SparseInnerProductLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + bool IS_VALID_CUDA = false; +#ifndef CPU_ONLY + IS_VALID_CUDA = CAFFE_TEST_CUDA_PROP.major >= 2; +#endif + if (Caffe::mode() == Caffe::CPU || + sizeof(Dtype) == 4 || IS_VALID_CUDA) { + LayerParameter layer_param; + InnerProductParameter* inner_product_param = + layer_param.mutable_inner_product_param(); + inner_product_param->set_num_output(10); + inner_product_param->mutable_weight_filler()->set_type("gaussian"); + inner_product_param->mutable_bias_filler()->set_type("gaussian"); + inner_product_param->mutable_bias_filler()->set_min(1); + inner_product_param->mutable_bias_filler()->set_max(2); + SparseInnerProductLayer<Dtype> layer(layer_param); + GradientChecker<Dtype> checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, -2); + } else { + LOG(ERROR) << "Skipping test due to old architecture."; + } +} + +} // namespace caffe diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp index 13e17be582b..1448dd3ae7f 100644 --- a/src/caffe/util/math_functions.cpp +++ b/src/caffe/util/math_functions.cpp @@ -4,6 +4,7 @@ #include #include "caffe/common.hpp" +#include "caffe/syncedmem.hpp" #include "caffe/util/math_functions.hpp" #include "caffe/util/rng.hpp" @@ -31,6 +32,131 @@ void caffe_cpu_gemm(const CBLAS_TRANSPOSE TransA, ldb, beta, C, N); } +// Computes C = alpha * op(A) * op(B) + beta * C, where op(A) is sparse: with +// TransA == CblasNoTrans, (A, indices, ptr) holds an M x K matrix in CSR; with +// TransA == CblasTrans it holds a K x M CSR matrix applied transposed. B is +// dense row-major, and orderC selects the memory layout of the dense C. +template <typename Dtype> +void caffe_cpu_csr_gemm(const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const Dtype alpha, const int nzz, + const Dtype* A, const int* indices, const int* ptr, + const Dtype* B, const Dtype beta, Dtype* C, + const CBLAS_ORDER orderC) { + if (TransA == CblasNoTrans) { // A is CSR + caffe_scal(M * N, beta, C); + if (orderC == CblasRowMajor) { + if (TransB == CblasNoTrans) { + for (int rowA = 0; rowA < M; rowA++) { + const int begin = ptr[rowA]; + const int end = ptr[rowA + 1]; + Dtype* CrowA = C + (N * rowA); + for (int pos = begin; pos < end; pos++) { + const Dtype* BcolAN = B + (indices[pos] * N); + const Dtype AatPos = alpha * A[pos]; + caffe_axpy(N, AatPos, BcolAN, CrowA, 1, 1); + } + } + } else { + for (int rowA = 0; rowA < M; rowA++) { + const int begin = ptr[rowA]; + const int end = ptr[rowA + 1]; + Dtype* CrowA = C + (N * rowA); + for (int pos = begin; pos < end; pos++) { + const Dtype AatPos = alpha * A[pos]; + const Dtype* BcolA = B + indices[pos]; + caffe_axpy(N, AatPos, BcolA, CrowA, K, 1); + } + } + } + } else { + if (TransB == CblasNoTrans) { + for (int rowA = 0; rowA < M; rowA++) { + const int begin = ptr[rowA]; + const int end = ptr[rowA + 1]; + Dtype* CrowA = C + rowA; + for (int pos = begin; pos < 
end; pos++) { + const Dtype* BcolAN = B + (indices[pos] * N); + const Dtype AatPos = alpha * A[pos]; + caffe_axpy(N, AatPos, BcolAN, CrowA, 1, M); + } + } + } else { + for (int rowA = 0; rowA < M; rowA++) { + const int begin = ptr[rowA]; + const int end = ptr[rowA + 1]; + Dtype* CrowA = C + rowA; + for (int pos = begin; pos < end; pos++) { + const Dtype* BcolA = B + indices[pos]; + const Dtype AatPos = alpha * A[pos]; + caffe_axpy(N, AatPos, BcolA, CrowA, K, M); + } + } + } + } + } else { // A is CSC + caffe_scal(M * N, beta, C); + if (orderC == CblasRowMajor) { + if (TransB == CblasNoTrans) { + for (int colA = 0; colA < K; colA++) { + const int begin = ptr[colA]; + const int end = ptr[colA + 1]; + const Dtype* BColAN = B + (colA * N); + for (int pos = begin; pos < end; pos++) { + caffe_axpy(N, A[pos] * alpha, BColAN, + C + (indices[pos] * N), 1, 1); + } + } + } else { + for (int colA = 0; colA < K; colA++) { + const int begin = ptr[colA]; + const int end = ptr[colA + 1]; + const Dtype* BColA = B + colA; + for (int pos = begin; pos < end; pos++) { + caffe_axpy(N, A[pos] * alpha, BColA, C + (indices[pos] * N), + K, 1); + } + } + } + } else { + if (TransB == CblasNoTrans) { + for (int colA = 0; colA < K; colA++) { + const int begin = ptr[colA]; + const int end = ptr[colA + 1]; + const Dtype* BColAN = B + (colA * N); + for (int pos = begin; pos < end; pos++) { + caffe_axpy(N, A[pos] * alpha, BColAN, C + indices[pos], 1, M); + } + } + + } else { + for (int colA = 0; colA < K; colA++) { + const int begin = ptr[colA]; + const int end = ptr[colA + 1]; + const Dtype* BColA = B + colA; + for (int pos = begin; pos < end; pos++) { + caffe_axpy(N, A[pos] * alpha, BColA, C + indices[pos], K, M); + } + } + } + } + } +} + +template void caffe_cpu_csr_gemm<float>(const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int M, const int N, const int K, + const float alpha, const int nzz, + const float* A, const int* indices, + const int* ptr, const float* B, + const float beta, float* C, + const CBLAS_ORDER orderC); + +template void caffe_cpu_csr_gemm<double>(const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int M, const int N, const int K, + const double alpha, const int nzz, + const double* A, const int* indices, + const int* ptr, const double* B, + const double beta, double* C, + const CBLAS_ORDER orderC); + template <> void caffe_cpu_gemv<float>(const CBLAS_TRANSPOSE TransA, const int M, const int N, const float alpha, const float* A, const float* x, @@ -47,11 +173,15 @@ void caffe_cpu_gemv<float>(const CBLAS_TRANSPOSE TransA, const int M, template <> void caffe_axpy<float>(const int N, const float alpha, const float* X, - float* Y) { cblas_saxpy(N, alpha, X, 1, Y, 1); } + float* Y, const int ldx, const int ldy) { + cblas_saxpy(N, alpha, X, ldx, Y, ldy); +} template <> void caffe_axpy<double>(const int N, const double alpha, const double* X, - double* Y) { cblas_daxpy(N, alpha, X, 1, Y, 1); } + double* Y, const int ldx, const int ldy) { + cblas_daxpy(N, alpha, X, ldx, Y, ldy); +} template <typename Dtype> void caffe_set(const int N, const Dtype alpha, Dtype* Y) { diff --git a/src/caffe/util/math_functions.cu b/src/caffe/util/math_functions.cu index 43e65eb9a69..ea45781d444 100644 --- a/src/caffe/util/math_functions.cu +++ b/src/caffe/util/math_functions.cu @@ -10,6 +10,8 @@ #include "caffe/common.hpp" #include "caffe/util/math_functions.hpp" +#define THREADS_PER_BLOCK_CSR 32 + namespace caffe { template <> @@ -441,4 +443,293 @@ void caffe_gpu_rng_gaussian(const int n, const double mu, const double sigma, 
curandGenerateNormalDouble(Caffe::curand_generator(), r, n, mu, sigma)); } +// Each block owns a (rowA, colC) pair: its 32-thread warp strides over the +// nonzeros of rowA and the partial products are reduced in shared memory. +template <typename Dtype> +__device__ void caffe_gpu_csr_gemm_kernel_core(const int M, const int N, + const int K, const Dtype alpha, + int nzz, const Dtype* A, + const int* indices, + const int* ptr, const Dtype* B, + const int ldb1, const int ldb2, + const Dtype beta, Dtype* C, + const int ldc1, const int ldc2) { + __shared__ volatile Dtype sums[THREADS_PER_BLOCK_CSR * 2]; + + for (int rowA = blockIdx.x; rowA < M; rowA += gridDim.x) { + const int begin = ptr[rowA]; + const int end = ptr[rowA + 1]; + const int offset_c_part = rowA * ldc1; + for (int colC = blockIdx.y; colC < N; colC += gridDim.y) { + Dtype sum = 0.0; + const int offset_b_part = colC * ldb2; + for (int pos = begin + threadIdx.x; pos < end; pos += + THREADS_PER_BLOCK_CSR) { + const int colA = indices[pos]; + sum += A[pos] * B[colA * ldb1 + offset_b_part]; + } + sums[threadIdx.x] = sum; + // zero the upper half of the buffer so the unconditional reduction + // below never reads uninitialized shared memory + sums[threadIdx.x + THREADS_PER_BLOCK_CSR] = 0; + __syncthreads(); + + /* hardcoded warp-synchronous reduction for 32 threads */ + sums[threadIdx.x] += sums[threadIdx.x + 16]; + sums[threadIdx.x] += sums[threadIdx.x + 8]; + sums[threadIdx.x] += sums[threadIdx.x + 4]; + sums[threadIdx.x] += sums[threadIdx.x + 2]; + sums[threadIdx.x] += sums[threadIdx.x + 1]; + + if (threadIdx.x == 0) { + const int offsetC = offset_c_part + colC * ldc2; + C[offsetC] = beta * C[offsetC] + alpha * sums[0]; + } + } + } +} + +template <typename Dtype> +__global__ void caffe_gpu_csr_gemm_kernel(const CBLAS_TRANSPOSE TransB, + const int M, const int N, const int K, + const Dtype alpha, int nzz, + const Dtype* A, const int* indices, + const int* ptr, const Dtype* B, + const Dtype beta, Dtype* C, + const CBLAS_ORDER orderC) { + if (orderC == CblasRowMajor) { + if (TransB == CblasNoTrans) { + caffe_gpu_csr_gemm_kernel_core(M, N, K, alpha, nzz, A, indices, ptr, B, N, + 1, beta, C, N, 1); + } else { + caffe_gpu_csr_gemm_kernel_core(M, N, K, alpha, nzz, A, indices, ptr, B, 1, + K, beta, C, N, 1); + } + } else { + if (TransB == CblasNoTrans) { + caffe_gpu_csr_gemm_kernel_core(M, N, K, alpha, nzz, A, indices, ptr, B, N, + 1, beta, C, 1, M); + } else { + caffe_gpu_csr_gemm_kernel_core(M, N, K, alpha, nzz, A, indices, ptr, B, 1, + K, beta, C, 1, M); + } + } +} + +template <typename Dtype> +__device__ void caffe_gpu_csr_rank1_update_kernel_core(const int M, const int N, + const Dtype alpha, + const Dtype* A, + const int* indices, + const int* ptr, + const Dtype* B, int ldb, + Dtype* C, const int ldc1, + const int ldc2) { + const int begin = ptr[0]; + const int end = ptr[1]; + for (int pos = blockIdx.x * blockDim.x + begin + threadIdx.x; pos < end; + pos += blockDim.x * gridDim.x) { + const Dtype valA = A[pos] * alpha; + const int offset_part = indices[pos] * ldc1; + for (int colC = blockIdx.y * blockDim.y + threadIdx.y; colC < N; + colC += blockDim.y * gridDim.y) { + const int C_offset = offset_part + colC * ldc2; + C[C_offset] = C[C_offset] + B[colC * ldb] * valA; + } + } +} + +// C = alpha A * B^T + C where A and B are vectors.
+// A is a sparse vector and B is a dense vector +template <typename Dtype> +__device__ void caffe_gpu_csr_rank1_update_kernel(const int M, const int N, + const Dtype alpha, + const Dtype* A, + const int* indices, + const int* ptr, + const Dtype* B, int ldb, + Dtype* C, + const CBLAS_ORDER orderC) { + if (orderC == CblasRowMajor) { + caffe_gpu_csr_rank1_update_kernel_core(M, N, alpha, A, indices, ptr, B, ldb, + C, N, 1); + } else { + caffe_gpu_csr_rank1_update_kernel_core(M, N, alpha, A, indices, ptr, B, ldb, + C, 1, M); + } +} + +// Applies op(A) = A^T as a sequence of sparse rank-1 updates, one per stored +// CSR row (each stored row is one column of op(A)). +template <typename Dtype> +__global__ void caffe_gpu_csr_rank1_update_kernel_multi( + const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, + const Dtype alpha, const Dtype* A, const int* indices, const int* ptr, + const Dtype* B, int ldb, Dtype* C, const CBLAS_ORDER orderC) { + if (TransB == CblasNoTrans) { + for (int i = 0; i < K; i++) { + caffe_gpu_csr_rank1_update_kernel(M, N, alpha, A, indices, ptr + i, + B + (N * i), 1, C, orderC); + } + } else { + for (int i = 0; i < K; i++) { + caffe_gpu_csr_rank1_update_kernel(M, N, alpha, A, indices, ptr + i, B + i, + K, C, orderC); + } + } +} + +template<> +void caffe_gpu_csr_gemm<float>(const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, + const int N, const int K, const float alpha, + int nzz, const float* A, const int* indices, + const int* ptr, const float* B, const float beta, + float* C, const CBLAS_ORDER orderC) { + if (TransA == CblasNoTrans) { + dim3 grids(M, N); + dim3 threads(THREADS_PER_BLOCK_CSR, 1); + caffe_gpu_csr_gemm_kernel<<<grids, threads>>>(TransB, M, N, K, + alpha, nzz, A, indices, ptr, B, beta, C, orderC); + } else { + // scale C by beta + if (beta != 1.0) { + CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), M * N, &beta, C, 1)); + } + const int average_nzz_per_row = nzz / K + 1; + dim3 grids((average_nzz_per_row + 64 - 1) / 64, N); + dim3 threads(64, 1); + caffe_gpu_csr_rank1_update_kernel_multi<<<grids, threads>>>(TransB, + M, N, K, + alpha, A, indices, ptr, B, 1, C, orderC); + } +} + +template<> +void caffe_gpu_csr_gemm<double>(const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, + const int N, const int K, const double alpha, + int nzz, const double* A, const int* indices, + const int* ptr, const double* B, + const double beta, double* C, + const CBLAS_ORDER orderC) { + if (TransA == CblasNoTrans) { + dim3 grids(M, N); + dim3 threads(THREADS_PER_BLOCK_CSR, 1); + caffe_gpu_csr_gemm_kernel<<<grids, threads>>>(TransB, M, N, K, + alpha, nzz, A, indices, ptr, B, beta, C, orderC); + } else { + // scale C by beta + if (beta != 1.0) { + CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), M * N, &beta, C, 1)); + } + const int average_nzz_per_row = nzz / K + 1; + dim3 grids((average_nzz_per_row + 64 - 1) / 64, N); + dim3 threads(64, 1); + caffe_gpu_csr_rank1_update_kernel_multi<<<grids, threads>>>(TransB, + M, N, K, + alpha, A, indices, ptr, B, 1, C, orderC); + } +} + +/* Alternative implementation using cusparse; it is very slow, at least when +used like this: +template <> +void caffe_gpu_csr_gemm<float>(const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, + const float alpha, int nzz, const float* A, const int* indices, const int* ptr, const float* B, const float beta, + float* C, const CBLAS_ORDER orderC) { + + //std::cout << "M: " << M << " N: " << N << " K: " << K << " NZZ: " << nzz << "\n"; + + int ldb = (TransB == CblasNoTrans) ? N : K; + cusparseOperation_t cuTransA = + (TransA == CblasNoTrans) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE; + cusparseOperation_t cuTransB = + (TransB == CblasNoTrans) ?
CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE; + + float* Bt; + int ldb_t; + + bool require_transpose_B = (cuTransA == CUSPARSE_OPERATION_TRANSPOSE) && (cuTransB == CUSPARSE_OPERATION_TRANSPOSE); + if (require_transpose_B) { + // we must transpose B explicitly because cusparse does not support this + // combination of operations directly + ldb_t = K; + const float zero = 0.0; + const float one = 1.0; + CUDA_CHECK(cudaMalloc((void**)&Bt, sizeof(float)*K*N)); + CUBLAS_CHECK(cublasSgeam(Caffe::cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_T, K, N, &one, B, ldb, &zero, B, ldb, Bt, ldb_t)); + } + + int msparse = (TransA == CblasNoTrans) ? M : K; + int ksparse = (TransA == CblasNoTrans) ? K : M; + if (orderC == CblasRowMajor) { + float* Ct; + CUDA_CHECK(cudaMalloc((void**)&Ct, sizeof(float)*M*N)); + const float zero = 0.0; + const float one = 1.0; + if (require_transpose_B) { + CUSPARSE_CHECK(cusparseScsrmm2(Caffe::cusparse_handle(), cuTransA, CUSPARSE_OPERATION_NON_TRANSPOSE, msparse, N, ksparse, nzz, &alpha, Caffe::cusparse_mat_descr(), A, ptr, indices, Bt, ldb_t, &zero, Ct, M)); + CUDA_CHECK(cudaFree(Bt)); + } else { + CUSPARSE_CHECK(cusparseScsrmm2(Caffe::cusparse_handle(), cuTransA, cuTransB, msparse, N, ksparse, nzz, &alpha, Caffe::cusparse_mat_descr(), A, ptr, indices, B, ldb, &zero, Ct, M)); + } + CUBLAS_CHECK(cublasSgeam(Caffe::cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, N, M, &one, Ct, M, &beta, C, N, C, N)); + CUDA_CHECK(cudaFree(Ct)); + } else { + // column-major C is what cusparse produces natively; the row-major B was + // already accounted for by flipping cuTransB above + if (require_transpose_B) { + CUSPARSE_CHECK(cusparseScsrmm2(Caffe::cusparse_handle(), cuTransA, CUSPARSE_OPERATION_NON_TRANSPOSE, msparse, N, ksparse, nzz, &alpha, Caffe::cusparse_mat_descr(), A, ptr, indices, Bt, ldb_t, &beta, C, M)); + CUDA_CHECK(cudaFree(Bt)); + } else { + CUSPARSE_CHECK(cusparseScsrmm2(Caffe::cusparse_handle(), cuTransA, cuTransB, msparse, N, ksparse, nzz, &alpha, Caffe::cusparse_mat_descr(), A, ptr, indices, B, ldb, &beta, C, M)); + } + } +} + +template <> +void caffe_gpu_csr_gemm<double>(const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, + const double alpha, int nzz, const double* A, const int* indices, const int* ptr, const double* B, const double beta, + double* C, const CBLAS_ORDER orderC) { + + //std::cout << "M: " << M << " N: " << N << " K: " << K << " NZZ: " << nzz; + int ldb = (TransB == CblasNoTrans) ? N : K; + cusparseOperation_t cuTransA = + (TransA == CblasNoTrans) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE; + cusparseOperation_t cuTransB = + (TransB == CblasNoTrans) ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE; + + double* Bt; + int ldb_t; + bool require_transpose_B = (cuTransA == CUSPARSE_OPERATION_TRANSPOSE) && (cuTransB == CUSPARSE_OPERATION_TRANSPOSE); + if (require_transpose_B) { + // we must transpose B explicitly because cusparse does not support this + // combination of operations directly + ldb_t = K; + const double zero = 0.0; + const double one = 1.0; + CUDA_CHECK(cudaMalloc((void**)&Bt, sizeof(double)*K*N)); + CUBLAS_CHECK(cublasDgeam(Caffe::cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_T, K, N, &one, B, ldb, &zero, B, ldb, Bt, ldb_t)); + } + + int msparse = (TransA == CblasNoTrans) ? M : K; + int ksparse = (TransA == CblasNoTrans) ?
K : M; + if (orderC == CblasRowMajor) { + double* Ct; + CUDA_CHECK(cudaMalloc((void**)&Ct, sizeof(double)*M*N)); + const double zero = 0.0; + const double one = 1.0; + if (require_transpose_B) { + CUSPARSE_CHECK(cusparseDcsrmm2(Caffe::cusparse_handle(), cuTransA, CUSPARSE_OPERATION_NON_TRANSPOSE, msparse, N, ksparse, nzz, &alpha, Caffe::cusparse_mat_descr(), A, ptr, indices, Bt, ldb_t, &zero, Ct, M)); + CUDA_CHECK(cudaFree(Bt)); + } else { + CUSPARSE_CHECK(cusparseDcsrmm2(Caffe::cusparse_handle(), cuTransA, cuTransB, msparse, N, ksparse, nzz, &alpha, Caffe::cusparse_mat_descr(), A, ptr, indices, B, ldb, &zero, Ct, M)); + } + CUBLAS_CHECK(cublasDgeam(Caffe::cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, N, M, &one, Ct, M, &beta, C, N, C, N)); + CUDA_CHECK(cudaFree(Ct)); + } else { + // column-major C is what cusparse produces natively; the row-major B was + // already accounted for by flipping cuTransB above + if (require_transpose_B) { + CUSPARSE_CHECK(cusparseDcsrmm2(Caffe::cusparse_handle(), cuTransA, CUSPARSE_OPERATION_NON_TRANSPOSE, msparse, N, ksparse, nzz, &alpha, Caffe::cusparse_mat_descr(), A, ptr, indices, Bt, ldb_t, &beta, C, M)); + CUDA_CHECK(cudaFree(Bt)); + } else { + CUSPARSE_CHECK(cusparseDcsrmm2(Caffe::cusparse_handle(), cuTransA, cuTransB, msparse, N, ksparse, nzz, &alpha, Caffe::cusparse_mat_descr(), A, ptr, indices, B, ldb, &beta, C, M)); + } + } +} + +*/ + } // namespace caffe
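A note for readers of the TestCsrGemm cases above: every test encodes the same sparse operand in the (data, indices, ptr) CSR triplet. ptr = {0, 2, 3} says row 0 owns stored positions [0, 2) and row 1 owns [2, 3), so A = {1.0, 2.0, 3.0} with indices = {0, 2, 1} is the dense 2x3 matrix [[1, 0, 2], [0, 3, 0]]. The following standalone sketch (plain C++, no Caffe dependencies, illustrative only -- it is not part of the patch) decodes that triplet with the same loop structure as the row-major NoTrans/NoTrans branch of caffe_cpu_csr_gemm and reproduces TestCsrGemm1's expected output:

#include <cstdio>

int main() {
  // CSR triplet from the tests: dense A = [[1, 0, 2], [0, 3, 0]] (2x3).
  const float A[] = {1.0f, 2.0f, 3.0f};
  const int indices[] = {0, 2, 1};
  const int ptr[] = {0, 2, 3};
  const int M = 2, N = 2;  // C is M x N; A has 3 columns, matching B's rows
  // B is 3x2 row-major: [[4, 7], [5, 8], [6, 9]].
  const float B[] = {4.0f, 7.0f, 5.0f, 8.0f, 6.0f, 9.0f};
  float C[4] = {0.0f, 0.0f, 0.0f, 0.0f};  // alpha = 1, beta = 1 on a zero C

  // Each stored nonzero A[pos] in row 'row' scales row indices[pos] of B
  // into row 'row' of C -- the axpy pattern used by caffe_cpu_csr_gemm.
  for (int row = 0; row < M; ++row)
    for (int pos = ptr[row]; pos < ptr[row + 1]; ++pos)
      for (int col = 0; col < N; ++col)
        C[row * N + col] += A[pos] * B[indices[pos] * N + col];

  for (int i = 0; i < M * N; ++i) printf("%g ", C[i]);  // 16 25 15 24
  printf("\n");
  return 0;
}

Row 0 is 1*[4, 7] + 2*[6, 9] = [16, 25] and row 1 is 3*[5, 8] = [15, 24], which is exactly CCheck = {16.0, 25.0, 15.0, 24.0} in TestCsrGemm1.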
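The strided caffe_axpy overload added in math_functions.cpp is what lets the CPU loops address rows and columns of the dense operands without materializing a transpose. A short sketch of the new semantics, under the signature shown in the diff:

// caffe_axpy<float>(N, a, X, Y, ldx, ldy) maps to cblas_saxpy(N, a, X, ldx, Y, ldy),
// i.e. Y[i * ldy] += a * X[i * ldx] for i in [0, N).
// A stride of 1 walks a contiguous row of a row-major matrix; a stride equal
// to the row count of a column-major view (K for B, M for C) walks a column.
// The TransB/orderC branch combinations in caffe_cpu_csr_gemm differ only in
// which pair of strides they hand to caffe_axpy.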
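The speed tests also suggest how the new primitive is meant to be wired into SparseInnerProductLayer. The layer's Forward_cpu/Backward_cpu bodies are not part of the hunks shown here, so the call shapes below are an assumption that simply mirrors test_speed_forward and test_speed_backward (bottom is batch x features in CSR; W is num_output x features row-major, as in the dense InnerProductLayer):

// Forward, top = bottom * W^T, matching test_speed_forward's configuration
// (TransA = NoTrans, TransB = Trans, orderC = RowMajor):
//   caffe_cpu_csr_gemm<Dtype>(CblasNoTrans, CblasTrans,
//       batch, num_output, features, (Dtype)1., nzz,
//       bottom_data, bottom_indices, bottom_ptr, W, (Dtype)0.,
//       top_data, CblasRowMajor);
//
// Weight gradient, dW accumulated as bottom^T * top_diff, matching
// test_speed_backward's configuration (TransA = Trans, TransB = NoTrans,
// orderC = ColMajor); a features x num_output column-major C occupies the
// same memory as the num_output x features row-major dW blob:
//   caffe_cpu_csr_gemm<Dtype>(CblasTrans, CblasNoTrans,
//       features, num_output, batch, (Dtype)1., nzz,
//       bottom_data, bottom_indices, bottom_ptr, top_diff, (Dtype)1.,
//       dW, CblasColMajor);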