diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index 79336d431eca50..82ca7b5f224ae1 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -1842,7 +1842,7 @@ bool SlotPaddleBoxDataFeed::Start() { this->finish_start_ = true; #if defined(PADDLE_WITH_CUDA) && defined(_LINUX) CHECK(paddle::platform::is_gpu_place(this->place_)); - pack_ = new MiniBatchGpuPack(this->GetPlace(), used_slots_info_); + pack_ = BatchGpuPackMgr().get(this->GetPlace(), used_slots_info_); #endif return true; } @@ -1852,23 +1852,45 @@ int SlotPaddleBoxDataFeed::Next() { if (offset_index_ >= static_cast(batch_offsets_.size())) { return 0; } - auto& batch = batch_offsets_[offset_index_++]; if (enable_pv_merge_ && phase == 1) { // join phase : output_pv_channel to consume_pv_channel this->batch_size_ = batch.second; if (this->batch_size_ != 0) { + batch_timer_.Resume(); PutToFeedPvVec(&pv_ins_[batch.first], this->batch_size_); + batch_timer_.Pause(); } else { VLOG(3) << "finish reading, batch size zero, thread_id=" << thread_id_; } return this->batch_size_; } else { this->batch_size_ = batch.second; + batch_timer_.Resume(); PutToFeedSlotVec(&records_[batch.first], this->batch_size_); + batch_timer_.Pause(); return this->batch_size_; } } +bool SlotPaddleBoxDataFeed::EnablePvMerge(void) { + return (enable_pv_merge_ && GetCurrentPhase() == 1); +} +int SlotPaddleBoxDataFeed::GetPackInstance(SlotRecord** ins) { + if (offset_index_ >= static_cast(batch_offsets_.size())) { + return 0; + } + auto& batch = batch_offsets_[offset_index_]; + *ins = &records_[batch.first]; + return batch.second; +} +int SlotPaddleBoxDataFeed::GetPackPvInstance(SlotPvInstance** pv_ins) { + if (offset_index_ >= static_cast(batch_offsets_.size())) { + return 0; + } + auto& batch = batch_offsets_[offset_index_]; + *pv_ins = &pv_ins_[batch.first]; + return batch.second; +} void SlotPaddleBoxDataFeed::AssignFeedVar(const Scope& scope) { CheckInit(); for (int i = 0; i < use_slot_size_; ++i) { @@ -1886,8 +1908,10 @@ void SlotPaddleBoxDataFeed::PutToFeedPvVec(const SlotPvInstance* pvs, int num) { paddle::platform::SetDeviceId( boost::get(place_).GetDeviceId()); pack_->pack_pvinstance(pvs, num); - GetRankOffsetGPU(); - BuildSlotBatchGPU(); + int ins_num = pack_->ins_num(); + int pv_num = pack_->pv_num(); + GetRankOffsetGPU(pv_num, ins_num); + BuildSlotBatchGPU(ins_num); #else int ins_number = 0; std::vector ins_vec; @@ -1971,7 +1995,7 @@ void SlotPaddleBoxDataFeed::PutToFeedSlotVec(const SlotRecord* ins_vec, paddle::platform::SetDeviceId( boost::get(place_).GetDeviceId()); pack_->pack_instance(ins_vec, num); - BuildSlotBatchGPU(); + BuildSlotBatchGPU(pack_->ins_num()); #else for (int j = 0; j < use_slot_size_; ++j) { auto& feed = feed_vec_[j]; @@ -2056,32 +2080,47 @@ void SlotPaddleBoxDataFeed::PutToFeedSlotVec(const SlotRecord* ins_vec, // LOG(WARNING) << "[" << name << "]" << ostream.str(); //} -void SlotPaddleBoxDataFeed::BuildSlotBatchGPU(void) { +void SlotPaddleBoxDataFeed::BuildSlotBatchGPU(const int ins_num) { #if defined(PADDLE_WITH_CUDA) && defined(_LINUX) - int ins_num = pack_->ins_num(); + fill_timer_.Resume(); + int offset_cols_size = (ins_num + 1); - slot_value_offsets_.resize(use_slot_size_ * offset_cols_size); - auto gpu_slot_offsets = memory::AllocShared( - this->GetPlace(), (use_slot_size_ * offset_cols_size) * sizeof(size_t)); + size_t slot_total_bytes = + (use_slot_size_ * offset_cols_size) * sizeof(size_t); + if (gpu_slot_offsets_ == nullptr) { + gpu_slot_offsets_ = 
memory::AllocShared(this->GetPlace(), slot_total_bytes); + } else if (gpu_slot_offsets_->size() < slot_total_bytes) { + auto buf = memory::AllocShared(this->GetPlace(), slot_total_bytes); + gpu_slot_offsets_.swap(buf); + buf = nullptr; + } auto& value = pack_->value(); const UsedSlotGpuType* used_slot_gpu_types = static_cast(pack_->get_gpu_slots()); - FillSlotValueOffset(&slot_value_offsets_, ins_num, use_slot_size_, - reinterpret_cast(gpu_slot_offsets->ptr()), + FillSlotValueOffset(ins_num, use_slot_size_, + reinterpret_cast(gpu_slot_offsets_->ptr()), value.d_uint64_offset.data(), uint64_use_slot_size_, value.d_float_offset.data(), float_use_slot_size_, used_slot_gpu_types); + fill_timer_.Pause(); + size_t* d_slot_offsets = reinterpret_cast(gpu_slot_offsets_->ptr()); + + offset_timer_.Resume(); + thread_local std::vector offsets; + offsets.resize(offset_cols_size); + thread_local HostBuffer h_tensor_ptrs; + h_tensor_ptrs.resize(use_slot_size_); - std::vector offsets(offset_cols_size, 0); - std::vector h_tensor_ptrs(use_slot_size_, nullptr); for (int j = 0; j < use_slot_size_; ++j) { auto& feed = feed_vec_[j]; if (feed == nullptr) { + h_tensor_ptrs[j] = nullptr; continue; } - memcpy(offsets.data(), &slot_value_offsets_[j * offset_cols_size], - offset_cols_size * sizeof(size_t)); + + cudaMemcpy(offsets.data(), &d_slot_offsets[j * offset_cols_size], + offset_cols_size * sizeof(size_t), cudaMemcpyDeviceToHost); int total_instance = offsets.back(); CHECK(total_instance >= 0) << "slot idx:" << j << ", total instance:" << total_instance; @@ -2094,9 +2133,7 @@ void SlotPaddleBoxDataFeed::BuildSlotBatchGPU(void) { h_tensor_ptrs[j] = feed->mutable_data({total_instance, 1}, this->place_); } - - LoD data_lod{offsets}; - feed_vec_[j]->set_lod(data_lod); + feed->set_lod({offsets}); if (info.dense) { if (info.inductive_shape_index != -1) { @@ -2106,17 +2143,19 @@ void SlotPaddleBoxDataFeed::BuildSlotBatchGPU(void) { feed->Resize(framework::make_ddim(info.local_shape)); } } + offset_timer_.Pause(); - auto buf = - memory::AllocShared(this->GetPlace(), use_slot_size_ * sizeof(void*)); - void** dest_gpu_p = reinterpret_cast(buf->ptr()); - // fprintf(stderr, "after create tensor\n"); - // pack_->copy2tensor_ptr(h_tensor_ptrs); + trans_timer_.Resume(); + if (slot_buf_ptr_ == nullptr) { + slot_buf_ptr_ = + memory::AllocShared(this->GetPlace(), use_slot_size_ * sizeof(void*)); + } + void** dest_gpu_p = reinterpret_cast(slot_buf_ptr_->ptr()); cudaMemcpy(dest_gpu_p, h_tensor_ptrs.data(), use_slot_size_ * sizeof(void*), cudaMemcpyHostToDevice); CopyForTensor(ins_num, use_slot_size_, dest_gpu_p, - (const size_t*)gpu_slot_offsets->ptr(), + (const size_t*)gpu_slot_offsets_->ptr(), (const uint64_t*)value.d_uint64_keys.data(), (const int*)value.d_uint64_offset.data(), (const int*)value.d_uint64_lens.data(), uint64_use_slot_size_, @@ -2124,6 +2163,7 @@ void SlotPaddleBoxDataFeed::BuildSlotBatchGPU(void) { (const int*)value.d_float_offset.data(), (const int*)value.d_float_lens.data(), float_use_slot_size_, used_slot_gpu_types); + trans_timer_.Pause(); #endif } int SlotPaddleBoxDataFeed::GetCurrentPhase() { @@ -2134,16 +2174,15 @@ int SlotPaddleBoxDataFeed::GetCurrentPhase() { return box_ptr->Phase(); } } -void SlotPaddleBoxDataFeed::GetRankOffsetGPU(void) { +void SlotPaddleBoxDataFeed::GetRankOffsetGPU(const int pv_num, + const int ins_num) { #if defined(PADDLE_WITH_CUDA) && defined(_LINUX) int max_rank = 3; // the value is setting int col = max_rank * 2 + 1; - int ins_num = pack_->ins_num(); - auto& value = 
pack_->value(); int* tensor_ptr = rank_offset_->mutable_data({ins_num, col}, this->place_); - CopyRankOffset(tensor_ptr, ins_num, pack_->pv_num(), max_rank, + CopyRankOffset(tensor_ptr, ins_num, pv_num, max_rank, (const int*)value.d_rank.data(), (const int*)value.d_cmatch.data(), (const int*)value.d_ad_offset.data(), col); @@ -2631,9 +2670,24 @@ bool SlotPaddleBoxDataFeed::ParseOneInstance(const std::string& line, ////////////////////////////// pack //////////////////////////////////// #if defined(PADDLE_WITH_CUDA) && defined(_LINUX) -SlotPaddleBoxDataFeed::MiniBatchGpuPack::MiniBatchGpuPack( - const paddle::platform::Place& place, - const std::vector& infos) { +static void SetCPUAffinity(int tid) { + std::vector& cores = boxps::get_train_cores(); + if (cores.empty()) { + VLOG(0) << "not found binding read ins thread cores"; + return; + } + + size_t core_num = cores.size() / 2; + if (core_num < 8) { + return; + } + cpu_set_t mask; + CPU_ZERO(&mask); + CPU_SET(cores[core_num + (tid % core_num)], &mask); + pthread_setaffinity_np(pthread_self(), sizeof(mask), &mask); +} +MiniBatchGpuPack::MiniBatchGpuPack(const paddle::platform::Place& place, + const std::vector& infos) { place_ = place; // paddle::platform::SetDeviceId(boost::get(place).GetDeviceId()); // paddle::platform::CUDADeviceContext* context = @@ -2661,13 +2715,26 @@ SlotPaddleBoxDataFeed::MiniBatchGpuPack::MiniBatchGpuPack( ++used_float_num_; } } - copy_host2device(&gpu_slots_, gpu_used_slots_); + copy_host2device(&gpu_slots_, gpu_used_slots_.data(), gpu_used_slots_.size()); } -SlotPaddleBoxDataFeed::MiniBatchGpuPack::~MiniBatchGpuPack() {} +MiniBatchGpuPack::~MiniBatchGpuPack() {} + +void MiniBatchGpuPack::reset(const paddle::platform::Place& place) { + place_ = place; + stream_ = dynamic_cast( + platform::DeviceContextPool::Instance().Get( + boost::get(place))) + ->stream(); + ins_num_ = 0; + pv_num_ = 0; + enable_pv_ = false; + + pack_timer_.Reset(); + trans_timer_.Reset(); +} -void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_pvinstance( - const SlotPvInstance* pv_ins, int num) { +void MiniBatchGpuPack::pack_pvinstance(const SlotPvInstance* pv_ins, int num) { pv_num_ = num; buf_.h_ad_offset.resize(num + 1); buf_.h_ad_offset[0] = 0; @@ -2689,8 +2756,7 @@ void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_pvinstance( pack_instance(&ins_vec_[0], ins_number); } -void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_all_data( - const SlotRecord* ins_vec, int num) { +void MiniBatchGpuPack::pack_all_data(const SlotRecord* ins_vec, int num) { int uint64_total_num = 0; int float_total_num = 0; @@ -2759,8 +2825,7 @@ void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_all_data( CHECK(float_total_num == static_cast(buf_.h_float_lens.back())) << "float value length error"; } -void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_uint64_data( - const SlotRecord* ins_vec, int num) { +void MiniBatchGpuPack::pack_uint64_data(const SlotRecord* ins_vec, int num) { int uint64_total_num = 0; buf_.h_float_lens.clear(); @@ -2809,8 +2874,7 @@ void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_uint64_data( CHECK(uint64_total_num == static_cast(buf_.h_uint64_lens.back())) << "uint64 value length error"; } -void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_float_data( - const SlotRecord* ins_vec, int num) { +void MiniBatchGpuPack::pack_float_data(const SlotRecord* ins_vec, int num) { int float_total_num = 0; buf_.h_uint64_lens.clear(); @@ -2859,8 +2923,8 @@ void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_float_data( << "float value length error"; } -void 
SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_instance( - const SlotRecord* ins_vec, int num) { +void MiniBatchGpuPack::pack_instance(const SlotRecord* ins_vec, int num) { + pack_timer_.Resume(); ins_num_ = num; CHECK(used_uint64_num_ > 0 || used_float_num_ > 0); // uint64 and float @@ -2871,11 +2935,13 @@ void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_instance( } else { // only float pack_float_data(ins_vec, num); } + pack_timer_.Pause(); // to gpu transfer_to_gpu(); } -void SlotPaddleBoxDataFeed::MiniBatchGpuPack::transfer_to_gpu(void) { +void MiniBatchGpuPack::transfer_to_gpu(void) { + trans_timer_.Resume(); if (enable_pv_) { copy_host2device(&value_.d_ad_offset, buf_.h_ad_offset); copy_host2device(&value_.d_rank, buf_.h_rank); @@ -2889,6 +2955,7 @@ void SlotPaddleBoxDataFeed::MiniBatchGpuPack::transfer_to_gpu(void) { copy_host2device(&value_.d_float_keys, buf_.h_float_keys); copy_host2device(&value_.d_float_offset, buf_.h_float_offset); cudaStreamSynchronize(stream_); + trans_timer_.Pause(); } #endif diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu index 73fdf7f000792e..3a97dffe98678f 100644 --- a/paddle/fluid/framework/data_feed.cu +++ b/paddle/fluid/framework/data_feed.cu @@ -25,8 +25,8 @@ namespace framework { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ i += blockDim.x * gridDim.x) -__global__ void CopyForTensorKernel(FeatureItem* src, void** dest, - size_t* offset, char* type, +__global__ void CopyForTensorKernel(FeatureItem *src, void **dest, + size_t *offset, char *type, size_t total_size, size_t row_size, size_t col_size) { CUDA_KERNEL_LOOP(i, row_size * col_size) { @@ -42,14 +42,14 @@ __global__ void CopyForTensorKernel(FeatureItem* src, void** dest, right = offset[row_id * (col_size + 1) + col_id + 1] - offset[(row_id - 1) * (col_size + 1) + col_id + 1]; } - - uint64_t* up = NULL; - float* fp = NULL; + + uint64_t *up = NULL; + float *fp = NULL; if (type[row_id] == 'f') { - fp = reinterpret_cast(dest[row_id]); + fp = reinterpret_cast(dest[row_id]); } else { - up = reinterpret_cast( - *(reinterpret_cast(dest) + row_id)); + up = reinterpret_cast( + *(reinterpret_cast(dest) + row_id)); } size_t begin = offset[row_id * (col_size + 1) + col_id + 1] + offset[(row_size - 1) * (col_size + 1) + col_id] - @@ -72,10 +72,10 @@ __global__ void CopyForTensorKernel(FeatureItem* src, void** dest, } void MultiSlotInMemoryDataFeed::CopyForTensor( - const paddle::platform::Place& place, FeatureItem* src, void** dest, - size_t* offset, char* type, size_t total_size, size_t row_size, + const paddle::platform::Place &place, FeatureItem *src, void **dest, + size_t *offset, char *type, size_t total_size, size_t row_size, size_t col_size) { - auto stream = dynamic_cast( + auto stream = dynamic_cast( platform::DeviceContextPool::Instance().Get( boost::get(place))) ->stream(); @@ -91,7 +91,7 @@ const int CUDA_NUM_THREADS = 512; inline int GET_BLOCKS(const int N) { return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; } - +// fill slot values __global__ void FillSlotValueOffsetKernel( const int ins_num, const int used_slot_num, size_t *slot_value_offsets, const int *uint64_offsets, const int uint64_slot_size, @@ -127,8 +127,7 @@ __global__ void FillSlotValueOffsetKernel( } void SlotPaddleBoxDataFeed::FillSlotValueOffset( - std::vector *cpu_slot_value_offsets, const int ins_num, - const int used_slot_num, size_t *slot_value_offsets, + const int ins_num, const int used_slot_num, size_t *slot_value_offsets, const int *uint64_offsets, const int 
uint64_slot_size, const int *float_offsets, const int float_slot_size, const UsedSlotGpuType *used_slots) { @@ -140,9 +139,6 @@ void SlotPaddleBoxDataFeed::FillSlotValueOffset( stream>>>( ins_num, used_slot_num, slot_value_offsets, uint64_offsets, uint64_slot_size, float_offsets, float_slot_size, used_slots); - cudaMemcpyAsync(cpu_slot_value_offsets->data(), slot_value_offsets, - cpu_slot_value_offsets->size() * sizeof(size_t), - cudaMemcpyDeviceToHost, stream); cudaStreamSynchronize(stream); } diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 2782cd2088eda0..1d6f8163cc8aea 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -19,6 +19,7 @@ limitations under the License. */ #define _LINUX #endif +#include #include #include // NOLINT #include @@ -30,7 +31,6 @@ limitations under the License. */ #include #include #include - #include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/blocking_queue.h" #include "paddle/fluid/framework/channel.h" @@ -47,6 +47,8 @@ limitations under the License. */ USE_INT_STAT(STAT_total_feasign_num_in_mem); USE_INT_STAT(STAT_slot_pool_size); DECLARE_int32(padbox_record_pool_max_size); +DECLARE_int32(padbox_slotpool_thread_num); + namespace paddle { namespace framework { @@ -874,11 +876,15 @@ class SlotObjPool { public: SlotObjPool() : max_capacity_(FLAGS_padbox_record_pool_max_size) { ins_chan_ = MakeChannel(); - thread_ = std::thread([this]() { run(); }); + for (int i = 0; i < FLAGS_padbox_slotpool_thread_num; ++i) { + threads_.push_back(std::thread([this]() { run(); })); + } } ~SlotObjPool() { ins_chan_->Close(); - thread_.join(); + for (auto& t : threads_) { + t.join(); + } } void set_max_capacity(size_t max_capacity) { max_capacity_ = max_capacity; } void get(std::vector* output, int n) { @@ -956,7 +962,7 @@ class SlotObjPool { private: size_t max_capacity_; Channel ins_chan_; - std::thread thread_; + std::vector threads_; std::mutex mutex_; SlotObjAllocator alloc_; }; @@ -981,7 +987,16 @@ class ISlotParser { const std::string& line, std::function&, int)> GetInsFunc) = 0; }; - +struct UsedSlotInfo { + int idx; + int slot_value_idx; + std::string slot; + std::string type; + bool dense; + std::vector local_shape; + int total_dims_without_inductive; + int inductive_shape_index; +}; #if defined(PADDLE_WITH_CUDA) && defined(_LINUX) struct UsedSlotGpuType { int is_uint64_value; @@ -1018,107 +1033,192 @@ struct CudaBuffer { malloc(size); } }; -#endif +template +struct HostBuffer { + T* host_buffer; + size_t buf_size; + size_t data_len; -class SlotPaddleBoxDataFeed : public DataFeed { - struct UsedSlotInfo { - int idx; - int slot_value_idx; - std::string slot; - std::string type; - bool dense; - std::vector local_shape; - int total_dims_without_inductive; - int inductive_shape_index; - }; -#if defined(PADDLE_WITH_CUDA) && defined(_LINUX) - struct BatchCPUValue { - std::vector h_uint64_lens; - std::vector h_uint64_keys; - std::vector h_uint64_offset; - - std::vector h_float_lens; - std::vector h_float_keys; - std::vector h_float_offset; - - std::vector h_rank; - std::vector h_cmatch; - std::vector h_ad_offset; - }; + HostBuffer() { + host_buffer = NULL; + buf_size = 0; + data_len = 0; + } + ~HostBuffer() { free(); } + + T* data() { return host_buffer; } + const T* data() const { return host_buffer; } + size_t size() const { return data_len; } + void clear() { free(); } + T& back() { return host_buffer[data_len - 1]; } - struct BatchGPUValue { - CudaBuffer d_uint64_lens; 
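The header changes above replace the std::vector staging buffers in BatchCPUValue with a grow-only, pinned-memory HostBuffer<T> (its template arguments are lost in this rendering of the diff). Below is a condensed, standalone sketch of that pattern using only the CUDA runtime; PinnedBuffer and the main() driver are illustrative names, not code from the patch. The same grow-only policy is applied to the device-side AllocShared buffers (gpu_slot_offsets_, slot_buf_ptr_, pull_push_buf), and pinned pages let the cudaMemcpyAsync calls in transfer_to_gpu run asynchronously on the feed stream.

// pinned_buffer_sketch.cu -- grow-only pinned host buffer, mirroring
// HostBuffer<T> in data_feed.h (illustrative, not the patch's code).
#include <cuda_runtime.h>
#include <cstdio>

template <typename T>
struct PinnedBuffer {
  T* ptr = nullptr;
  size_t capacity = 0;  // allocated element count
  size_t len = 0;       // logical element count

  ~PinnedBuffer() { release(); }
  void release() {
    if (ptr != nullptr) {
      cudaFreeHost(ptr);
      ptr = nullptr;
    }
    capacity = 0;
  }
  // Reallocate only when the request exceeds capacity, so per-batch resize()
  // calls stop paying the pinned-allocation cost after the largest batch.
  void resize(size_t n) {
    if (n > capacity) {
      release();
      cudaHostAlloc(reinterpret_cast<void**>(&ptr), n * sizeof(T),
                    cudaHostAllocDefault);
      capacity = n;
    }
    len = n;
  }
  T* data() { return ptr; }
  size_t size() const { return len; }
};

int main() {
  PinnedBuffer<int> h_lens;
  h_lens.resize(1024);  // first batch: allocates pinned memory
  h_lens.resize(512);   // smaller batch: no reallocation
  h_lens.resize(2048);  // larger batch: one reallocation
  printf("capacity=%zu len=%zu\n", h_lens.capacity, h_lens.size());
  return 0;
}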
- CudaBuffer d_uint64_keys; - CudaBuffer d_uint64_offset; + T& operator[](size_t i) { return host_buffer[i]; } + const T& operator[](size_t i) const { return host_buffer[i]; } + void malloc(size_t len) { + buf_size = len; + cudaHostAlloc(reinterpret_cast(&host_buffer), buf_size * sizeof(T), + cudaHostAllocDefault); + } + void free() { + if (host_buffer != NULL) { + cudaFreeHost(host_buffer); + host_buffer = NULL; + } + buf_size = 0; + } + void resize(size_t size) { + if (size <= buf_size) { + data_len = size; + return; + } + data_len = size; + free(); + malloc(size); + } +}; - CudaBuffer d_float_lens; - CudaBuffer d_float_keys; - CudaBuffer d_float_offset; +struct BatchCPUValue { + HostBuffer h_uint64_lens; + HostBuffer h_uint64_keys; + HostBuffer h_uint64_offset; - CudaBuffer d_rank; - CudaBuffer d_cmatch; - CudaBuffer d_ad_offset; - }; + HostBuffer h_float_lens; + HostBuffer h_float_keys; + HostBuffer h_float_offset; - class MiniBatchGpuPack { - public: - MiniBatchGpuPack(const paddle::platform::Place& place, - const std::vector& infos); - ~MiniBatchGpuPack(); - void pack_pvinstance(const SlotPvInstance* pv_ins, int num); - void pack_instance(const SlotRecord* ins_vec, int num); - int ins_num() { return ins_num_; } - int pv_num() { return pv_num_; } - BatchGPUValue& value() { return value_; } - BatchCPUValue& cpu_value() { return buf_; } - UsedSlotGpuType* get_gpu_slots(void) { - return reinterpret_cast(gpu_slots_.data()); + HostBuffer h_rank; + HostBuffer h_cmatch; + HostBuffer h_ad_offset; +}; + +struct BatchGPUValue { + CudaBuffer d_uint64_lens; + CudaBuffer d_uint64_keys; + CudaBuffer d_uint64_offset; + + CudaBuffer d_float_lens; + CudaBuffer d_float_keys; + CudaBuffer d_float_offset; + + CudaBuffer d_rank; + CudaBuffer d_cmatch; + CudaBuffer d_ad_offset; +}; + +class SlotPaddleBoxDataFeed; +class MiniBatchGpuPack { + public: + MiniBatchGpuPack(const paddle::platform::Place& place, + const std::vector& infos); + ~MiniBatchGpuPack(); + void reset(const paddle::platform::Place& place); + void pack_pvinstance(const SlotPvInstance* pv_ins, int num); + void pack_instance(const SlotRecord* ins_vec, int num); + int ins_num() { return ins_num_; } + int pv_num() { return pv_num_; } + BatchGPUValue& value() { return value_; } + BatchCPUValue& cpu_value() { return buf_; } + UsedSlotGpuType* get_gpu_slots(void) { + return reinterpret_cast(gpu_slots_.data()); + } + SlotRecord* get_records(void) { return &ins_vec_[0]; } + double pack_time_span(void) { return pack_timer_.ElapsedSec(); } + double trans_time_span(void) { return trans_timer_.ElapsedSec(); } + + private: + void transfer_to_gpu(void); + void pack_all_data(const SlotRecord* ins_vec, int num); + void pack_uint64_data(const SlotRecord* ins_vec, int num); + void pack_float_data(const SlotRecord* ins_vec, int num); + + public: + template + void copy_host2device(CudaBuffer* buf, const T* val, size_t size) { + if (size == 0) { + return; } - SlotRecord* get_records(void) { return &ins_vec_[0]; } - - private: - void transfer_to_gpu(void); - void pack_all_data(const SlotRecord* ins_vec, int num); - void pack_uint64_data(const SlotRecord* ins_vec, int num); - void pack_float_data(const SlotRecord* ins_vec, int num); - - public: - template - void copy_host2device(CudaBuffer* buf, const std::vector& val) { - size_t size = val.size(); - if (size == 0) { - return; + buf->resize(size); + cudaMemcpyAsync(buf->data(), val, size * sizeof(T), cudaMemcpyHostToDevice, + stream_); + } + template + void copy_host2device(CudaBuffer* buf, const HostBuffer& val) 
{ + copy_host2device(buf, val.data(), val.size()); + } + + private: + paddle::platform::Place place_; + cudaStream_t stream_; + BatchGPUValue value_; + BatchCPUValue buf_; + int ins_num_ = 0; + int pv_num_ = 0; + + bool enable_pv_ = false; + int used_float_num_ = 0; + int used_uint64_num_ = 0; + int used_slot_size_ = 0; + + CudaBuffer gpu_slots_; + std::vector gpu_used_slots_; + std::vector ins_vec_; + + platform::Timer pack_timer_; + platform::Timer trans_timer_; +}; +class MiniBatchGpuPackMgr { + static const int MAX_DEIVCE_NUM = 16; + + public: + MiniBatchGpuPackMgr() { + for (int i = 0; i < MAX_DEIVCE_NUM; ++i) { + pack_list_[i] = nullptr; + } + } + ~MiniBatchGpuPackMgr() { + for (int i = 0; i < MAX_DEIVCE_NUM; ++i) { + if (pack_list_[i] == nullptr) { + continue; } - buf->resize(size); - cudaMemcpyAsync(buf->data(), val.data(), size * sizeof(T), - cudaMemcpyHostToDevice, stream_); + delete pack_list_[i]; + pack_list_[i] = nullptr; } + } + // one device one thread + MiniBatchGpuPack* get(const paddle::platform::Place& place, + const std::vector& infos) { + int device_id = boost::get(place).GetDeviceId(); + if (pack_list_[device_id] == nullptr) { + pack_list_[device_id] = new MiniBatchGpuPack(place, infos); + } else { + pack_list_[device_id]->reset(place); + } + return pack_list_[device_id]; + } - private: - paddle::platform::Place place_; - cudaStream_t stream_; - BatchGPUValue value_; - BatchCPUValue buf_; - int ins_num_ = 0; - int pv_num_ = 0; - - bool enable_pv_ = false; - int used_float_num_ = 0; - int used_uint64_num_ = 0; - int used_slot_size_ = 0; - - CudaBuffer gpu_slots_; - std::vector gpu_used_slots_; - std::vector ins_vec_; - }; + private: + MiniBatchGpuPack* pack_list_[MAX_DEIVCE_NUM]; +}; +// global mgr +inline MiniBatchGpuPackMgr& BatchGpuPackMgr() { + static MiniBatchGpuPackMgr mgr; + return mgr; +} #endif +class SlotPaddleBoxDataFeed : public DataFeed { public: SlotPaddleBoxDataFeed() { finish_start_ = false; } virtual ~SlotPaddleBoxDataFeed() { #if defined(PADDLE_WITH_CUDA) && defined(_LINUX) if (pack_ != nullptr) { - delete pack_; + LOG(WARNING) << "pack batch total time: " << batch_timer_.ElapsedSec() + << "[copy:" << pack_->trans_time_span() + << ",fill:" << fill_timer_.ElapsedSec() + << ",tensor:" << offset_timer_.ElapsedSec() + << ",trans:" << trans_timer_.ElapsedSec() + << "], batch cpu build mem: " << pack_->pack_time_span() + << "sec"; pack_ = nullptr; } #endif @@ -1157,6 +1257,10 @@ class SlotPaddleBoxDataFeed : public DataFeed { void GetUsedSlotIndex(std::vector* used_slot_index); // expand values void ExpandSlotRecord(SlotRecord* ins); + // pack + bool EnablePvMerge(void); + int GetPackInstance(SlotRecord** ins); + int GetPackPvInstance(SlotPvInstance** pv_ins); public: virtual void Init(const DataFeedDesc& data_feed_desc); @@ -1171,8 +1275,8 @@ class SlotPaddleBoxDataFeed : public DataFeed { void LoadIntoMemoryByLib(void); void PutToFeedPvVec(const SlotPvInstance* pvs, int num); void PutToFeedSlotVec(const SlotRecord* recs, int num); - void BuildSlotBatchGPU(void); - void GetRankOffsetGPU(void); + void BuildSlotBatchGPU(const int ins_num); + void GetRankOffsetGPU(const int pv_num, const int ins_num); void GetRankOffset(const SlotPvInstance* pv_vec, int pv_num, int ins_number); bool ParseOneInstance(const std::string& line, SlotRecord* rec); @@ -1181,8 +1285,7 @@ class SlotPaddleBoxDataFeed : public DataFeed { void CopyRankOffset(int* dest, const int ins_num, const int pv_num, const int max_rank, const int* ranks, const int* cmatchs, const int* ad_offsets, 
const int cols); - void FillSlotValueOffset(std::vector* cpu_slot_value_offsets, - const int ins_num, const int used_slot_num, + void FillSlotValueOffset(const int ins_num, const int used_slot_num, size_t* slot_value_offsets, const int* uint64_offsets, const int uint64_slot_size, const int* float_offsets, @@ -1229,8 +1332,15 @@ class SlotPaddleBoxDataFeed : public DataFeed { SlotRecord* records_ = nullptr; std::vector all_slots_info_; std::vector used_slots_info_; - std::vector slot_value_offsets_; std::string parser_so_path_; + std::shared_ptr gpu_slot_offsets_ = + nullptr; + std::shared_ptr slot_buf_ptr_ = + nullptr; + platform::Timer batch_timer_; + platform::Timer fill_timer_; + platform::Timer offset_timer_; + platform::Timer trans_timer_; }; template diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 951dcc0e9c07d0..7e868c30d4e403 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -1395,7 +1395,9 @@ class PadBoxSlotDataConsumer : public boxps::DataConsumer { : _dataset(dataset) { BoxWrapper::data_shuffle_->register_handler(this); } - virtual ~PadBoxSlotDataConsumer() {} + virtual ~PadBoxSlotDataConsumer() { + BoxWrapper::data_shuffle_->register_handler(nullptr); + } virtual void on_receive(const int client_id, const char* buff, int len) { _dataset->ReceiveSuffleData(client_id, buff, len); } @@ -1407,22 +1409,9 @@ class PadBoxSlotDataConsumer : public boxps::DataConsumer { PadBoxSlotDataset::PadBoxSlotDataset() { mpi_size_ = boxps::MPICluster::Ins().size(); mpi_rank_ = boxps::MPICluster::Ins().rank(); - - if (mpi_size_ > 1) { - finished_counter_ = mpi_size_; - mpi_flags_.assign(mpi_size_, 1); - VLOG(3) << "RegisterClientToClientMsgHandler"; - data_consumer_ = reinterpret_cast(new PadBoxSlotDataConsumer(this)); - VLOG(3) << "RegisterClientToClientMsgHandler done"; - } SlotRecordPool(); } -PadBoxSlotDataset::~PadBoxSlotDataset() { - if (data_consumer_ != nullptr) { - delete reinterpret_cast(data_consumer_); - data_consumer_ = nullptr; - } -} +PadBoxSlotDataset::~PadBoxSlotDataset() {} // create input channel and output channel void PadBoxSlotDataset::CreateChannel() { if (input_channel_ == nullptr) { @@ -1476,6 +1465,14 @@ void PadBoxSlotDataset::LoadIntoMemory() { std::vector load_threads; std::vector shuffle_threads; + if (mpi_size_ > 1) { + finished_counter_ = mpi_size_; + mpi_flags_.assign(mpi_size_, 1); + VLOG(3) << "RegisterClientToClientMsgHandler"; + data_consumer_ = reinterpret_cast(new PadBoxSlotDataConsumer(this)); + VLOG(3) << "RegisterClientToClientMsgHandler done"; + } + std::atomic ref(thread_num_); for (int64_t i = 0; i < thread_num_; ++i) { load_threads.push_back(std::thread([this, i, &ref]() { @@ -1505,6 +1502,10 @@ void PadBoxSlotDataset::LoadIntoMemory() { } } + if (data_consumer_ != nullptr) { + delete reinterpret_cast(data_consumer_); + data_consumer_ = nullptr; + } // shuffle_channel_->Clear(); // input_channel_->Clear(); diff --git a/paddle/fluid/framework/fleet/box_wrapper.cc b/paddle/fluid/framework/fleet/box_wrapper.cc index 7e7a3223fdca3a..0525a1c9807ad6 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.cc +++ b/paddle/fluid/framework/fleet/box_wrapper.cc @@ -11,7 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
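In data_feed.h above, MiniBatchGpuPack is no longer created and destroyed by each SlotPaddleBoxDataFeed: Start() asks the global BatchGpuPackMgr() for the pack bound to its device, which is constructed once and only reset() on later passes, so its CUDA buffers survive across passes (the manager's comment states the one-device-one-thread assumption, hence no locking). A stripped-down sketch of that get-or-reset pattern follows; Pack, Place and GlobalPackMgr are placeholders for the Paddle types, not the patch's code.

// pack_mgr_sketch.cc -- per-device singleton with lazy create / cheap reset.
#include <array>
#include <cstdio>

struct Place { int device_id; };

struct Pack {
  explicit Pack(const Place& p) : place(p) {
    printf("create pack for gpu %d\n", p.device_id);
  }
  void reset(const Place& p) {
    place = p;
    printf("reuse pack for gpu %d\n", p.device_id);
  }
  Place place;
};

class PackMgr {
  static const int kMaxDeviceNum = 16;  // MAX_DEIVCE_NUM in the diff

 public:
  ~PackMgr() {
    for (auto*& p : packs_) { delete p; p = nullptr; }
  }
  // One device is only touched by one reader thread, so no mutex is taken,
  // matching MiniBatchGpuPackMgr::get in the diff.
  Pack* get(const Place& place) {
    Pack*& slot = packs_[place.device_id];
    if (slot == nullptr) {
      slot = new Pack(place);  // first Start() on this device
    } else {
      slot->reset(place);      // later passes reuse the existing buffers
    }
    return slot;
  }

 private:
  std::array<Pack*, kMaxDeviceNum> packs_{};
};

inline PackMgr& GlobalPackMgr() {  // plays the role of BatchGpuPackMgr()
  static PackMgr mgr;
  return mgr;
}

int main() {
  Place gpu0{0};
  Pack* a = GlobalPackMgr().get(gpu0);
  Pack* b = GlobalPackMgr().get(gpu0);
  printf("same object: %d\n", a == b);  // 1: allocations are reused
  return 0;
}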
- #ifdef PADDLE_WITH_BOX_PS #include "paddle/fluid/framework/fleet/box_wrapper.h" #include @@ -31,7 +30,7 @@ cudaStream_t BoxWrapper::stream_list_[8]; int BoxWrapper::embedx_dim_ = 8; int BoxWrapper::expand_embed_dim_ = 0; -void BasicAucCalculator::add_data(double pred, int label) { +void BasicAucCalculator::add_unlock_data(double pred, int label) { PADDLE_ENFORCE_GE(pred, 0.0, platform::errors::PreconditionNotMet( "pred should be greater than 0")); PADDLE_ENFORCE_LE(pred, 1.0, platform::errors::PreconditionNotMet( @@ -49,59 +48,94 @@ void BasicAucCalculator::add_data(double pred, int label) { pos, _table_size, platform::errors::PreconditionNotMet( "pos must be less than table_size, but its value is: %d", pos)); - std::lock_guard lock(_table_mutex); _local_abserr += fabs(pred - label); _local_sqrerr += (pred - label) * (pred - label); _local_pred += pred; - _table[label][pos]++; + ++_table[label][pos]; } -void BasicAucCalculator::add_data(const float* d_pred, const int64_t* d_label, int batch_size, - const paddle::platform::Place& place) { +void BasicAucCalculator::add_data(const float* d_pred, const int64_t* d_label, + int batch_size, + const paddle::platform::Place& place) { if (_mode_collect_in_gpu) { cuda_add_data(place, d_label, d_pred, batch_size); } else { - std::vector h_pred; - std::vector h_label; + thread_local std::vector h_pred; + thread_local std::vector h_label; + h_pred.resize(batch_size); + h_label.resize(batch_size); + cudaMemcpy(h_pred.data(), d_pred, sizeof(float) * batch_size, + cudaMemcpyDeviceToHost); + cudaMemcpy(h_label.data(), d_label, sizeof(int64_t) * batch_size, + cudaMemcpyDeviceToHost); + + std::lock_guard lock(_table_mutex); + for (int i = 0; i < batch_size; ++i) { + add_unlock_data(h_pred[i], h_label[i]); + } + } +} +// add mask data +void BasicAucCalculator::add_mask_data(const float* d_pred, const int64_t* d_label, + const int64_t *d_mask, + int batch_size, const paddle::platform::Place& place) { + if (_mode_collect_in_gpu) { + cuda_add_mask_data(place, d_label, d_pred, d_mask, batch_size); + } else { + thread_local std::vector h_pred; + thread_local std::vector h_label; + thread_local std::vector h_mask; h_pred.resize(batch_size); h_label.resize(batch_size); - cudaMemcpy(h_pred.data(), d_pred, sizeof(float) * batch_size, cudaMemcpyDeviceToHost); - cudaMemcpy(h_label.data(), d_label, sizeof(int64_t) * batch_size, cudaMemcpyDeviceToHost); + h_mask.resize(batch_size); + + cudaMemcpy(h_pred.data(), d_pred, sizeof(float) * batch_size, + cudaMemcpyDeviceToHost); + cudaMemcpy(h_label.data(), d_label, sizeof(int64_t) * batch_size, + cudaMemcpyDeviceToHost); + cudaMemcpy(h_mask.data(), d_mask, sizeof(int64_t) * batch_size, + cudaMemcpyDeviceToHost); + + std::lock_guard lock(_table_mutex); for (int i = 0; i < batch_size; ++i) { - add_data(h_pred[i], h_label[i]); + if (h_mask[i]) { + add_unlock_data(h_pred[i], h_label[i]); + } } } } void BasicAucCalculator::init(int table_size, int max_batch_size) { if (_mode_collect_in_gpu) { - PADDLE_ENFORCE_GE(max_batch_size, 0, platform::errors::PreconditionNotMet( - "max_batch_size should be greater than 0 in mode_collect_in_gpu")); - } - set_table_size(table_size); - set_max_batch_size(max_batch_size); - // init CPU memory - for (int i = 0; i < 2; i++) { - _table[i] = std::vector(); - } - // init GPU memory - if (_mode_collect_in_gpu) { - for (int i = 0; i < platform::GetCUDADeviceCount(); ++i) { - auto place = platform::CUDAPlace(i); - _d_positive.emplace_back( - memory::AllocShared(place, _table_size * sizeof(double))); - 
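BasicAucCalculator::add_data above no longer locks per sample: the device predictions and labels are first copied into thread_local host vectors, then the table mutex is taken once for the whole mini-batch while add_unlock_data accumulates each row (add_mask_data is the same loop but skips rows whose mask is not 1). A minimal CPU-only sketch of that per-batch locking follows; the cudaMemcpy of d_pred/d_label into host memory is elided and the class name is illustrative.

// auc_batch_lock_sketch.cc -- lock once per batch, accumulate row by row.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <mutex>
#include <vector>

class AucTable {
 public:
  explicit AucTable(int table_size) : table_size_(table_size) {
    table_[0].assign(table_size_, 0.0);
    table_[1].assign(table_size_, 0.0);
  }
  // No lock here: the caller must already hold the table mutex,
  // as with add_unlock_data in the diff.
  void add_unlock_data(double pred, int label) {
    int pos = std::min(static_cast<int>(pred * table_size_), table_size_ - 1);
    abserr_ += std::fabs(pred - label);
    sqrerr_ += (pred - label) * (pred - label);
    ++table_[label][pos];
  }
  // One lock_guard per mini-batch instead of one per sample.
  void add_batch(const float* pred, const int64_t* label,
                 const int64_t* mask, int n) {
    std::lock_guard<std::mutex> lock(mu_);
    for (int i = 0; i < n; ++i) {
      if (mask != nullptr && mask[i] != 1) continue;  // masked-AUC variant
      add_unlock_data(pred[i], static_cast<int>(label[i]));
    }
  }

 private:
  int table_size_;
  std::vector<double> table_[2];  // [0] negatives, [1] positives
  double abserr_ = 0.0, sqrerr_ = 0.0;
  std::mutex mu_;
};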
_d_negative.emplace_back( - memory::Alloc(place, _table_size * sizeof(double))); - _d_abserr.emplace_back( - memory::Alloc(place, _max_batch_size * sizeof(double))); - _d_sqrerr.emplace_back( - memory::Alloc(place, _max_batch_size * sizeof(double))); - _d_pred.emplace_back( - memory::Alloc(place, _max_batch_size * sizeof(double))); - } + PADDLE_ENFORCE_GE( + max_batch_size, 0, + platform::errors::PreconditionNotMet( + "max_batch_size should be greater than 0 in mode_collect_in_gpu")); + } + set_table_size(table_size); + set_max_batch_size(max_batch_size); + // init CPU memory + for (int i = 0; i < 2; i++) { + _table[i] = std::vector(); + } + // init GPU memory + if (_mode_collect_in_gpu) { + for (int i = 0; i < platform::GetCUDADeviceCount(); ++i) { + auto place = platform::CUDAPlace(i); + _d_positive.emplace_back( + memory::AllocShared(place, _table_size * sizeof(double))); + _d_negative.emplace_back( + memory::Alloc(place, _table_size * sizeof(double))); + _d_abserr.emplace_back( + memory::Alloc(place, _max_batch_size * sizeof(double))); + _d_sqrerr.emplace_back( + memory::Alloc(place, _max_batch_size * sizeof(double))); + _d_pred.emplace_back( + memory::Alloc(place, _max_batch_size * sizeof(double))); } - // reset - reset(); + } + // reset + reset(); } void BasicAucCalculator::reset() { @@ -128,10 +162,10 @@ void BasicAucCalculator::reset() { stream); cudaMemsetAsync(_d_negative[i]->ptr(), 0, sizeof(double) * _table_size, stream); - cudaMemsetAsync(_d_abserr[i]->ptr(), 0, - sizeof(double) * _max_batch_size, stream); - cudaMemsetAsync(_d_sqrerr[i]->ptr(), 0, - sizeof(double) * _max_batch_size, stream); + cudaMemsetAsync(_d_abserr[i]->ptr(), 0, sizeof(double) * _max_batch_size, + stream); + cudaMemsetAsync(_d_sqrerr[i]->ptr(), 0, sizeof(double) * _max_batch_size, + stream); cudaMemsetAsync(_d_pred[i]->ptr(), 0, sizeof(double) * _max_batch_size, stream); } @@ -141,51 +175,51 @@ void BasicAucCalculator::reset() { } void BasicAucCalculator::collect_data_nccl() { - // backup orginal device - int ori_device; - cudaGetDevice(&ori_device); - // transfer to CPU - platform::dynload::ncclGroupStart(); - // nccl allreduce sum - for (int i = 0; i < platform::GetCUDADeviceCount(); ++i) { - cudaSetDevice(i); - auto place = platform::CUDAPlace(i); - auto stream = dynamic_cast( - platform::DeviceContextPool::Instance().Get(place)) - ->stream(); - auto comm = platform::NCCLCommContext::Instance().Get(0, place); - platform::dynload::ncclAllReduce( - _d_positive[i]->ptr(), _d_positive[i]->ptr(), _table_size, - ncclFloat64, ncclSum, comm->comm(), stream); - platform::dynload::ncclAllReduce( - _d_negative[i]->ptr(), _d_negative[i]->ptr(), _table_size, - ncclFloat64, ncclSum, comm->comm(), stream); - platform::dynload::ncclAllReduce(_d_abserr[i]->ptr(), _d_abserr[i]->ptr(), - _max_batch_size, ncclFloat64, ncclSum, - comm->comm(), stream); - platform::dynload::ncclAllReduce(_d_sqrerr[i]->ptr(), _d_sqrerr[i]->ptr(), - _max_batch_size, ncclFloat64, ncclSum, - comm->comm(), stream); - platform::dynload::ncclAllReduce(_d_pred[i]->ptr(), _d_pred[i]->ptr(), - _max_batch_size, ncclFloat64, ncclSum, - comm->comm(), stream); - } - platform::dynload::ncclGroupEnd(); - // sync - for (int i = 0; i < platform::GetCUDADeviceCount(); ++i) { - cudaSetDevice(i); - auto place = platform::CUDAPlace(i); - auto stream = dynamic_cast( - platform::DeviceContextPool::Instance().Get(place)) - ->stream(); - cudaStreamSynchronize(stream); - } - // restore device - cudaSetDevice(ori_device); + // backup orginal device + int ori_device; + 
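When _mode_collect_in_gpu is on, the per-device positive/negative tables are merged by collect_data_nccl just below: an in-place ncclAllReduce (send and receive pointers identical) is issued for every device inside one ncclGroupStart/ncclGroupEnd pair, and every stream is synchronized afterwards. The following standalone sketch shows the same reduction with raw NCCL instead of Paddle's dynload and NCCLCommContext wrappers; error checking is omitted for brevity.

// nccl_allreduce_sketch.cu -- sum one double array across all visible GPUs,
// in place, as collect_data_nccl does for the AUC tables.
#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

int main() {
  int ndev = 0;
  cudaGetDeviceCount(&ndev);

  std::vector<int> devs(ndev);
  for (int i = 0; i < ndev; ++i) devs[i] = i;
  std::vector<ncclComm_t> comms(ndev);
  ncclCommInitAll(comms.data(), ndev, devs.data());

  const size_t table_size = 1000000;
  std::vector<double*> d_table(ndev);
  std::vector<cudaStream_t> streams(ndev);
  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(i);
    cudaMalloc(&d_table[i], table_size * sizeof(double));
    cudaMemset(d_table[i], 0, table_size * sizeof(double));
    cudaStreamCreate(&streams[i]);
  }

  // In-place sum: afterwards every device holds the elementwise total.
  ncclGroupStart();
  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(i);
    ncclAllReduce(d_table[i], d_table[i], table_size, ncclFloat64, ncclSum,
                  comms[i], streams[i]);
  }
  ncclGroupEnd();

  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(i);
    cudaStreamSynchronize(streams[i]);
  }

  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(i);
    cudaFree(d_table[i]);
    cudaStreamDestroy(streams[i]);
    ncclCommDestroy(comms[i]);
  }
  return 0;
}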
cudaGetDevice(&ori_device); + // transfer to CPU + platform::dynload::ncclGroupStart(); + // nccl allreduce sum + for (int i = 0; i < platform::GetCUDADeviceCount(); ++i) { + cudaSetDevice(i); + auto place = platform::CUDAPlace(i); + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get(place)) + ->stream(); + auto comm = platform::NCCLCommContext::Instance().Get(0, place); + platform::dynload::ncclAllReduce( + _d_positive[i]->ptr(), _d_positive[i]->ptr(), _table_size, ncclFloat64, + ncclSum, comm->comm(), stream); + platform::dynload::ncclAllReduce( + _d_negative[i]->ptr(), _d_negative[i]->ptr(), _table_size, ncclFloat64, + ncclSum, comm->comm(), stream); + platform::dynload::ncclAllReduce(_d_abserr[i]->ptr(), _d_abserr[i]->ptr(), + _max_batch_size, ncclFloat64, ncclSum, + comm->comm(), stream); + platform::dynload::ncclAllReduce(_d_sqrerr[i]->ptr(), _d_sqrerr[i]->ptr(), + _max_batch_size, ncclFloat64, ncclSum, + comm->comm(), stream); + platform::dynload::ncclAllReduce(_d_pred[i]->ptr(), _d_pred[i]->ptr(), + _max_batch_size, ncclFloat64, ncclSum, + comm->comm(), stream); + } + platform::dynload::ncclGroupEnd(); + // sync + for (int i = 0; i < platform::GetCUDADeviceCount(); ++i) { + cudaSetDevice(i); + auto place = platform::CUDAPlace(i); + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get(place)) + ->stream(); + cudaStreamSynchronize(stream); + } + // restore device + cudaSetDevice(ori_device); } void BasicAucCalculator::copy_data_d2h(int device) { - // backup orginal device + // backup orginal device int ori_device; cudaGetDevice(&ori_device); cudaSetDevice(device); @@ -201,11 +235,9 @@ void BasicAucCalculator::copy_data_d2h(int device) { h_sqrerr.resize(_max_batch_size); h_pred.resize(_max_batch_size); cudaMemcpyAsync(&_table[0][0], _d_negative[device]->ptr(), - _table_size * sizeof(double), cudaMemcpyDeviceToHost, - stream); + _table_size * sizeof(double), cudaMemcpyDeviceToHost, stream); cudaMemcpyAsync(&_table[1][0], _d_positive[device]->ptr(), - _table_size * sizeof(double), cudaMemcpyDeviceToHost, - stream); + _table_size * sizeof(double), cudaMemcpyDeviceToHost, stream); cudaMemcpyAsync(h_abserr.data(), _d_abserr[device]->ptr(), _max_batch_size * sizeof(double), cudaMemcpyDeviceToHost, stream); @@ -433,10 +465,8 @@ void BoxWrapper::EndFeedPass(boxps::PSAgentBase* agent) const { void BoxWrapper::BeginPass() const { int gpu_num = platform::GetCUDADeviceCount(); for (int i = 0; i < gpu_num; ++i) { - all_pull_timers_[i].Reset(); - boxps_pull_timers_[i].Reset(); - all_push_timers_[i].Reset(); - boxps_push_timers_[i].Reset(); + DeviceBoxData& dev = device_caches_[i]; + dev.ResetTimer(); } int ret = boxps_ptr_->BeginPass(); PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet( @@ -453,11 +483,12 @@ void BoxWrapper::EndPass(bool need_save_delta) const { ret, 0, platform::errors::PreconditionNotMet("EndPass failed in BoxPS.")); int gpu_num = platform::GetCUDADeviceCount(); for (int i = 0; i < gpu_num; ++i) { + auto& dev = device_caches_[i]; LOG(WARNING) << "gpu[" << i - << "] sparse pull span: " << all_pull_timers_[i].ElapsedSec() - << ", boxps span: " << boxps_pull_timers_[i].ElapsedSec() - << ", push span: " << all_push_timers_[i].ElapsedSec() - << ", boxps span:" << boxps_push_timers_[i].ElapsedSec(); + << "] sparse pull span: " << dev.all_pull_timer.ElapsedSec() + << ", boxps span: " << dev.boxps_pull_timer.ElapsedSec() + << ", push span: " << dev.all_push_timer.ElapsedSec() + << ", boxps span:" << 
dev.boxps_push_timer.ElapsedSec(); } } diff --git a/paddle/fluid/framework/fleet/box_wrapper.cu b/paddle/fluid/framework/fleet/box_wrapper.cu index aaa8f258d79157..62d2910d46db6f 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.cu +++ b/paddle/fluid/framework/fleet/box_wrapper.cu @@ -31,20 +31,12 @@ template __global__ void PullCopy( float** dest, const boxps::FeatureValueGpu* src, - const int64_t* len, int hidden, int expand_dim, int slot_num, int total_len, - uint64_t** keys, int* total_dims) { + const int hidden, const int expand_dim, const int total_len, + uint64_t** keys, int* total_dims, const int64_t* slot_lens, + const int slot_num, const int* key2slot) { CUDA_KERNEL_LOOP(i, total_len) { - int low = 0; - int high = slot_num - 1; - while (low < high) { - int mid = (low + high) / 2; - if (i < len[mid]) - high = mid; - else - low = mid + 1; - } - int x = low; - int y = i - (x ? len[x - 1] : 0); + int x = key2slot[i]; + int y = i - (x ? slot_lens[x - 1] : 0); if (*(keys[x] + y) == 0) { *(dest[x] + y * hidden) = 0; *(dest[x] + y * hidden + 1) = 0; @@ -82,21 +74,29 @@ __global__ void PullCopy( } // end kernel loop } -__global__ void CopyKeysKernel(uint64_t** src_keys, uint64_t* dest_total_keys, - const int64_t* len, int slot_num, - int total_len) { +__global__ void FillKey2Slot(const int total_len, const int64_t* slot_lens, + const int slot_num, int* key2slots) { CUDA_KERNEL_LOOP(i, total_len) { int low = 0; int high = slot_num - 1; while (low < high) { int mid = (low + high) / 2; - if (i < len[mid]) + if (i < slot_lens[mid]) { high = mid; - else + } else { low = mid + 1; + } } - int x = low; - int y = i - (x ? len[x - 1] : 0); + key2slots[i] = low; + } +} + +__global__ void CopyKeysKernel(const int total_len, uint64_t** src_keys, + uint64_t* dest_total_keys, + const int64_t* slot_lens, const int* key2slot) { + CUDA_KERNEL_LOOP(i, total_len) { + int x = key2slot[i]; + int y = i - (x ? slot_lens[x - 1] : 0); dest_total_keys[i] = src_keys[x][y]; } } @@ -104,20 +104,12 @@ __global__ void CopyKeysKernel(uint64_t** src_keys, uint64_t* dest_total_keys, template __global__ void PushCopy( boxps::FeaturePushValueGpu* dest, float** src, - int64_t* len, int hidden, int expand_dim, int slot_num, int total_len, - int bs, int* slot_vector, int* total_dims) { + int hidden, int expand_dim, int total_len, int bs, const int* slot_vector, + const int* total_dims, const int64_t* slot_lens, const int slot_num, + const int* key2slot) { CUDA_KERNEL_LOOP(i, total_len) { - int low = 0; - int high = slot_num - 1; - while (low < high) { - int mid = (low + high) / 2; - if (i < len[mid]) - high = mid; - else - low = mid + 1; - } - int x = low; - int y = i - (x ? len[low - 1] : 0); + int x = key2slot[i]; + int y = i - (x ? 
slot_lens[x - 1] : 0); (dest + i)->slot = slot_vector[x]; (dest + i)->show = *(src[x] + y * hidden); (dest + i)->clk = *(src[x] + y * hidden + 1); @@ -147,45 +139,61 @@ __global__ void PushCopy( } } +__device__ void add_calculator_value(const int table_size, const float pred, + const int64_t label, const int idx, + double* positive, double* negative, + double* abs_error, double* sqr_error, + double* local_pred) { + int pos = static_cast(pred * table_size); + if (pos >= table_size) { + pos = table_size - 1; + } + if (label == 0) { + atomicAdd(negative + pos, 1.0); + } else { + atomicAdd(positive + pos, 1.0); + } + double err = pred - label; + abs_error[idx] += fabs(err); + sqr_error[idx] += err * err; + local_pred[idx] += pred; +} + __global__ void AddBasicCalculator(const float* pred, const int64_t* label, double* positive, double* negative, double* abs_error, double* sqr_error, double* local_pred, int len, int table_size) { CUDA_KERNEL_LOOP(ins_idx, len) { - int pos = static_cast(pred[ins_idx] * table_size); - if (pos >= table_size) { - pos = table_size - 1; - } - if (label[ins_idx] == 0) { - atomicAdd(negative + pos, 1.0); - // negative[pos]++; - } else { - atomicAdd(positive + pos, 1.0); - // positive[pos]++; + add_calculator_value(table_size, pred[ins_idx], label[ins_idx], ins_idx, + positive, negative, abs_error, sqr_error, local_pred); + } +} + +__global__ void AddMaskCalculator(const float* pred, const int64_t* label, + const int64_t* mask, double* positive, + double* negative, double* abs_error, + double* sqr_error, double* local_pred, + int len, int table_size) { + CUDA_KERNEL_LOOP(ins_idx, len) { + if (mask[ins_idx] != 1) { + continue; } - double err = pred[ins_idx] - label[ins_idx]; - abs_error[ins_idx] += fabs(err); - sqr_error[ins_idx] += err * err; - local_pred[ins_idx] += pred[ins_idx]; + add_calculator_value(table_size, pred[ins_idx], label[ins_idx], ins_idx, + positive, negative, abs_error, sqr_error, local_pred); } } void BoxWrapper::CopyForPull(const paddle::platform::Place& place, - uint64_t** gpu_keys, - const std::vector& values, - void* total_values_gpu, const int64_t* gpu_len, - const int slot_num, const int hidden_size, - const int expand_embed_dim, + uint64_t** gpu_keys, float** gpu_values, + void* total_values_gpu, const int64_t* slot_lens, + const int slot_num, const int* key2slot, + const int hidden_size, const int expand_embed_dim, const int64_t total_length, int* total_dims) { auto stream = dynamic_cast( platform::DeviceContextPool::Instance().Get( BOOST_GET_CONST(platform::CUDAPlace, place))) ->stream(); - auto buf_value = memory::AllocShared(place, values.size() * sizeof(float*)); - float** gpu_values = reinterpret_cast(buf_value->ptr()); - cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*), - cudaMemcpyHostToDevice); #define EMBEDX_CASE(i, ...) 
\ case i: { \ constexpr size_t EmbedxDim = i; \ @@ -205,8 +213,8 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place, gpu_values, \ reinterpret_cast*>( \ total_values_gpu), \ - gpu_len, hidden_size, expand_embed_dim, slot_num, total_length, \ - gpu_keys, total_dims); \ + hidden_size, expand_embed_dim, total_length, gpu_keys, total_dims, \ + slot_lens, slot_num, key2slot); \ } break switch (hidden_size - 3) { @@ -225,48 +233,30 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place, void BoxWrapper::CopyKeys(const paddle::platform::Place& place, uint64_t** origin_keys, uint64_t* total_keys, - const int64_t* gpu_len, int slot_num, int total_len) { + const int64_t* slot_lens, int slot_num, int total_len, + int* key2slot) { auto stream = dynamic_cast( platform::DeviceContextPool::Instance().Get( BOOST_GET_CONST(platform::CUDAPlace, place))) ->stream(); + FillKey2Slot<<<(total_len + 512 - 1) / 512, 512, 0, stream>>>( + total_len, slot_lens, slot_num, key2slot); CopyKeysKernel<<<(total_len + 512 - 1) / 512, 512, 0, stream>>>( - origin_keys, total_keys, gpu_len, slot_num, total_len); + total_len, origin_keys, total_keys, slot_lens, key2slot); cudaStreamSynchronize(stream); } void BoxWrapper::CopyForPush(const paddle::platform::Place& place, - const std::vector& grad_values, - void* total_grad_values_gpu, - const std::vector& slot_lengths, - const int hidden_size, const int expand_embed_dim, + float** grad_values, void* total_grad_values_gpu, + const int* d_slot_vector, const int64_t* slot_lens, + const int slot_num, const int hidden_size, + const int expand_embed_dim, const int64_t total_length, const int batch_size, - int* total_dims) { + const int* total_dims, const int* key2slot) { auto stream = dynamic_cast( platform::DeviceContextPool::Instance().Get( BOOST_GET_CONST(platform::CUDAPlace, place))) ->stream(); - auto slot_lengths_lod = slot_lengths; - for (int i = 1; i < slot_lengths_lod.size(); i++) { - slot_lengths_lod[i] += slot_lengths_lod[i - 1]; - } - auto buf_grad_value = - memory::AllocShared(place, grad_values.size() * sizeof(float*)); - auto buf_length = - memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t)); - auto buf_slot_vector = - memory::AllocShared(place, slot_lengths_lod.size() * sizeof(int)); - - float** gpu_values = reinterpret_cast(buf_grad_value->ptr()); - int64_t* gpu_len = reinterpret_cast(buf_length->ptr()); - int* d_slot_vector = reinterpret_cast(buf_slot_vector->ptr()); - - cudaMemcpy(gpu_values, grad_values.data(), - grad_values.size() * sizeof(float*), cudaMemcpyHostToDevice); - cudaMemcpy(gpu_len, slot_lengths_lod.data(), - slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice); - cudaMemcpy(d_slot_vector, slot_vector_.data(), - slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice); #define EMBEDX_CASE(i, ...) \ case i: { \ @@ -279,16 +269,15 @@ void BoxWrapper::CopyForPush(const paddle::platform::Place& place, } \ } break -#define EXPAND_EMBED_PUSH_CASE(i, ...) \ - case i: { \ - constexpr size_t ExpandDim = i; \ - PushCopy<<<(total_length + 512 - 1) / 512, 512, 0, stream>>>( \ - reinterpret_cast*>( \ - total_grad_values_gpu), \ - gpu_values, gpu_len, hidden_size, expand_embed_dim, \ - slot_lengths.size(), total_length, batch_size, d_slot_vector, \ - total_dims); \ +#define EXPAND_EMBED_PUSH_CASE(i, ...) 
\ + case i: { \ + constexpr size_t ExpandDim = i; \ + PushCopy<<<(total_length + 512 - 1) / 512, 512, 0, stream>>>( \ + reinterpret_cast*>( \ + total_grad_values_gpu), \ + grad_values, hidden_size, expand_embed_dim, total_length, batch_size, \ + d_slot_vector, total_dims, slot_lens, slot_num, key2slot); \ } break switch (hidden_size - 3) { @@ -309,7 +298,6 @@ void BoxWrapper::CopyForPush(const paddle::platform::Place& place, void BasicAucCalculator::cuda_add_data(const paddle::platform::Place& place, const int64_t* label, const float* pred, int len) { - auto stream = dynamic_cast( platform::DeviceContextPool::Instance().Get( BOOST_GET_CONST(platform::CUDAPlace, place))) @@ -327,6 +315,25 @@ void BasicAucCalculator::cuda_add_data(const paddle::platform::Place& place, reinterpret_cast(_d_pred[i]->ptr()), len, _table_size); } +void BasicAucCalculator::cuda_add_mask_data( + const paddle::platform::Place& place, const int64_t* label, + const float* pred, const int64_t* mask, int len) { + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get( + BOOST_GET_CONST(platform::CUDAPlace, place))) + ->stream(); + int i = BOOST_GET_CONST(platform::CUDAPlace, place).GetDeviceId(); + + cudaSetDevice(i); + + AddMaskCalculator<<<(len + 512 - 1) / 512, 512, 0, stream>>>( + pred, label, mask, reinterpret_cast(_d_positive[i]->ptr()), + reinterpret_cast(_d_negative[i]->ptr()), + reinterpret_cast(_d_abserr[i]->ptr()), + reinterpret_cast(_d_sqrerr[i]->ptr()), + reinterpret_cast(_d_pred[i]->ptr()), len, _table_size); +} + } // end namespace framework } // end namespace paddle #endif diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h index 07bf7b17504b21..ea641f49892c11 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.h +++ b/paddle/fluid/framework/fleet/box_wrapper.h @@ -53,15 +53,19 @@ namespace framework { #ifdef PADDLE_WITH_BOX_PS class BasicAucCalculator { public: - BasicAucCalculator(bool mode_collect_in_gpu=false): - _mode_collect_in_gpu(mode_collect_in_gpu) {} - void init(int table_size, int max_batch_size=0); + explicit BasicAucCalculator(bool mode_collect_in_gpu = false) + : _mode_collect_in_gpu(mode_collect_in_gpu) {} + void init(int table_size, int max_batch_size = 0); void reset(); + // add single data in CPU with LOCK, deprecated + void add_unlock_data(double pred, int label); // add batch data void add_data(const float* d_pred, const int64_t* d_label, int batch_size, const paddle::platform::Place& place); - // add single data in CPU with LOCK, deprecated - void add_data(double pred, int label); + // add mask data + void add_mask_data(const float* d_pred, const int64_t* d_label, + const int64_t* d_mask, int batch_size, + const paddle::platform::Place& place); void compute(); int table_size() const { return _table_size; } double bucket_error() const { return _bucket_error; } @@ -76,9 +80,15 @@ class BasicAucCalculator { double& local_abserr() { return _local_abserr; } double& local_sqrerr() { return _local_sqrerr; } double& local_pred() { return _local_pred; } - void cuda_add_data(const paddle::platform::Place& place, - const int64_t* label, const float* pred, int len); + // lock and unlock + std::mutex& table_mutex(void) { return _table_mutex; } + private: + void cuda_add_data(const paddle::platform::Place& place, const int64_t* label, + const float* pred, int len); + void cuda_add_mask_data(const paddle::platform::Place& place, + const int64_t* label, const float* pred, + const int64_t* mask, int len); void 
calculate_bucket_error(); protected: @@ -100,9 +110,7 @@ class BasicAucCalculator { std::vector> _d_pred; private: - void set_table_size(int table_size) { - _table_size = table_size; - } + void set_table_size(int table_size) { _table_size = table_size; } void set_max_batch_size(int max_batch_size) { _max_batch_size = max_batch_size; } @@ -118,6 +126,32 @@ class BasicAucCalculator { }; class BoxWrapper { + struct DeviceBoxData { + LoDTensor keys_tensor; + LoDTensor dims_tensor; + std::shared_ptr pull_push_buf = nullptr; + std::shared_ptr gpu_keys_ptr = nullptr; + std::shared_ptr gpu_values_ptr = nullptr; + + LoDTensor slot_lens; + LoDTensor d_slot_vector; + LoDTensor keys2slot; + + platform::Timer all_pull_timer; + platform::Timer boxps_pull_timer; + platform::Timer all_push_timer; + platform::Timer boxps_push_timer; + + int64_t total_key_length = 0; + + void ResetTimer(void) { + all_pull_timer.Reset(); + boxps_pull_timer.Reset(); + all_push_timer.Reset(); + boxps_push_timer.Reset(); + } + }; + public: virtual ~BoxWrapper() { if (file_manager_ != nullptr) { @@ -132,11 +166,9 @@ class BoxWrapper { delete p_agent_; p_agent_ = nullptr; } - if (all_pull_timers_ != nullptr) { - delete[] all_pull_timers_; - delete[] boxps_pull_timers_; - delete[] all_push_timers_; - delete[] boxps_push_timers_; + if (device_caches_ != nullptr) { + delete device_caches_; + device_caches_ = nullptr; } } BoxWrapper() { @@ -180,22 +212,22 @@ class BoxWrapper { const int batch_size); void CopyForPull(const paddle::platform::Place& place, uint64_t** gpu_keys, - const std::vector& values, void* total_values_gpu, - const int64_t* gpu_len, const int slot_num, - const int hidden_size, const int expand_embed_dim, - const int64_t total_length, int* total_dims); + float** gpu_values, void* total_values_gpu, + const int64_t* slot_lens, const int slot_num, + const int* key2slot, const int hidden_size, + const int expand_embed_dim, const int64_t total_length, + int* total_dims); - void CopyForPush(const paddle::platform::Place& place, - const std::vector& grad_values, - void* total_grad_values_gpu, - const std::vector& slot_lengths, + void CopyForPush(const paddle::platform::Place& place, float** grad_values, + void* total_grad_values_gpu, const int* slots, + const int64_t* slot_lens, const int slot_num, const int hidden_size, const int expand_embed_dim, const int64_t total_length, const int batch_size, - int* total_dims); + const int* total_dims, const int* key2slot); void CopyKeys(const paddle::platform::Place& place, uint64_t** origin_keys, uint64_t* total_keys, const int64_t* gpu_len, int slot_num, - int total_len); + int total_len, int* key2slot); void CheckEmbedSizeIsValid(int embedx_dim, int expand_embed_dim); @@ -207,7 +239,8 @@ class BoxWrapper { if (nullptr != s_instance_) { VLOG(3) << "Begin InitializeGPU"; std::vector stream_list; - for (int i = 0; i < platform::GetCUDADeviceCount(); ++i) { + int gpu_num = platform::GetCUDADeviceCount(); + for (int i = 0; i < gpu_num; ++i) { VLOG(3) << "before get context i[" << i << "]"; platform::CUDADeviceContext* context = dynamic_cast( @@ -226,15 +259,7 @@ class BoxWrapper { slot_name_omited_in_feedpass_.insert(slot_name); } slot_vector_ = slot_vector; - - int gpu_num = platform::GetCUDADeviceCount(); - keys_tensor.resize(gpu_num); - dims_tensor.resize(gpu_num); - - all_pull_timers_ = new platform::Timer[gpu_num]; - boxps_pull_timers_ = new platform::Timer[gpu_num]; - all_push_timers_ = new platform::Timer[gpu_num]; - boxps_push_timers_ = new platform::Timer[gpu_num]; + 
device_caches_ = new DeviceBoxData[gpu_num]; } } @@ -242,8 +267,9 @@ class BoxWrapper { void Finalize() { VLOG(3) << "Begin Finalize"; - if (nullptr != s_instance_) { + if (nullptr != s_instance_ && s_instance_->boxps_ptr_ != nullptr) { s_instance_->boxps_ptr_->Finalize(); + s_instance_->boxps_ptr_ = nullptr; } } @@ -376,8 +402,8 @@ class BoxWrapper { public: MetricMsg() {} MetricMsg(const std::string& label_varname, const std::string& pred_varname, - int metric_phase, int bucket_size = 1000000, bool mode_collect_in_gpu = false, - int max_batch_size = 0) + int metric_phase, int bucket_size = 1000000, + bool mode_collect_in_gpu = false, int max_batch_size = 0) : label_varname_(label_varname), pred_varname_(pred_varname), metric_phase_(metric_phase) { @@ -396,11 +422,11 @@ class BoxWrapper { const float* pred_data = NULL; get_data(exe_scope, label_varname_, &label_data, &label_len); get_data(exe_scope, pred_varname_, &pred_data, &pred_len); - PADDLE_ENFORCE_EQ(label_len, pred_len, platform::errors::PreconditionNotMet( - "the predict data length should be consistent with the label data length")); - int& batch_size = label_len; - auto cal = GetCalculator(); - cal->add_data(pred_data, label_data, batch_size, place); + PADDLE_ENFORCE_EQ(label_len, pred_len, + platform::errors::PreconditionNotMet( + "the predict data length should be consistent with " + "the label data length")); + calculator->add_data(pred_data, label_data, label_len, place); } template static void get_data(const Scope* exe_scope, const std::string& varname, @@ -473,7 +499,7 @@ class BoxWrapper { } virtual ~MultiTaskMetricMsg() {} void add_data(const Scope* exe_scope, - const paddle::platform::Place& place) override { + const paddle::platform::Place& place) override { std::vector cmatch_rank_data; get_data(exe_scope, cmatch_rank_varname_, &cmatch_rank_data); std::vector label_data; @@ -497,14 +523,15 @@ class BoxWrapper { batch_size, pred_data_list[i].size())); } auto cal = GetCalculator(); + std::lock_guard lock(cal->table_mutex()); for (size_t i = 0; i < batch_size; ++i) { auto cmatch_rank_it = std::find(cmatch_rank_v.begin(), cmatch_rank_v.end(), parse_cmatch_rank(cmatch_rank_data[i])); if (cmatch_rank_it != cmatch_rank_v.end()) { - cal->add_data(pred_data_list[std::distance(cmatch_rank_v.begin(), - cmatch_rank_it)][i], - label_data[i]); + cal->add_unlock_data(pred_data_list[std::distance( + cmatch_rank_v.begin(), cmatch_rank_it)][i], + label_data[i]); } } } @@ -545,7 +572,7 @@ class BoxWrapper { } virtual ~CmatchRankMetricMsg() {} void add_data(const Scope* exe_scope, - const paddle::platform::Place& place) override { + const paddle::platform::Place& place) override { std::vector cmatch_rank_data; get_data(exe_scope, cmatch_rank_varname_, &cmatch_rank_data); std::vector label_data; @@ -564,6 +591,7 @@ class BoxWrapper { "illegal batch size: cmatch_rank[%lu] and pred_data[%lu]", batch_size, pred_data.size())); auto cal = GetCalculator(); + std::lock_guard lock(cal->table_mutex()); for (size_t i = 0; i < batch_size; ++i) { const auto& cur_cmatch_rank = parse_cmatch_rank(cmatch_rank_data[i]); for (size_t j = 0; j < cmatch_rank_v.size(); ++j) { @@ -574,7 +602,7 @@ class BoxWrapper { is_matched = cmatch_rank_v[j] == cur_cmatch_rank; } if (is_matched) { - cal->add_data(pred_data[i], label_data[i]); + cal->add_unlock_data(pred_data[i], label_data[i]); break; } } @@ -590,30 +618,35 @@ class BoxWrapper { public: MaskMetricMsg(const std::string& label_varname, const std::string& pred_varname, int metric_phase, - const std::string& 
mask_varname, int bucket_size = 1000000) { + const std::string& mask_varname, int bucket_size = 1000000, + bool mode_collect_in_gpu = false, int max_batch_size = 0) { label_varname_ = label_varname; pred_varname_ = pred_varname; mask_varname_ = mask_varname; metric_phase_ = metric_phase; - calculator = new BasicAucCalculator(); - calculator->init(bucket_size); + calculator = new BasicAucCalculator(mode_collect_in_gpu); + calculator->init(bucket_size, max_batch_size); } virtual ~MaskMetricMsg() {} void add_data(const Scope* exe_scope, - const paddle::platform::Place& place) override { - std::vector label_data; - get_data(exe_scope, label_varname_, &label_data); - std::vector pred_data; - get_data(exe_scope, pred_varname_, &pred_data); - std::vector mask_data; - get_data(exe_scope, mask_varname_, &mask_data); + const paddle::platform::Place& place) override { + int label_len = 0; + const int64_t* label_data = NULL; + get_data(exe_scope, label_varname_, &label_data, &label_len); + + int pred_len = 0; + const float* pred_data = NULL; + get_data(exe_scope, pred_varname_, &pred_data, &pred_len); + + int mask_len = 0; + const int64_t* mask_data = NULL; + get_data(exe_scope, mask_varname_, &mask_data, &mask_len); + PADDLE_ENFORCE_EQ(label_len, mask_len, + platform::errors::PreconditionNotMet( + "the predict data length should be consistent with " + "the label data length")); auto cal = GetCalculator(); - auto batch_size = label_data.size(); - for (size_t i = 0; i < batch_size; ++i) { - if (mask_data[i] == 1) { - cal->add_data(pred_data[i], label_data[i]); - } - } + cal->add_mask_data(pred_data, label_data, mask_data, label_len, place); } protected: @@ -658,8 +691,10 @@ class BoxWrapper { int bucket_size = 1000000, bool mode_collect_in_gpu = false, int max_batch_size = 0) { if (method == "AucCalculator") { - metric_lists_.emplace(name, new MetricMsg(label_varname, pred_varname, - metric_phase, bucket_size, mode_collect_in_gpu, max_batch_size)); + metric_lists_.emplace( + name, + new MetricMsg(label_varname, pred_varname, metric_phase, bucket_size, + mode_collect_in_gpu, max_batch_size)); } else if (method == "MultiTaskAucCalculator") { metric_lists_.emplace( name, new MultiTaskMetricMsg(label_varname, pred_varname, @@ -673,7 +708,8 @@ class BoxWrapper { } else if (method == "MaskAucCalculator") { metric_lists_.emplace( name, new MaskMetricMsg(label_varname, pred_varname, metric_phase, - mask_varname, bucket_size)); + mask_varname, bucket_size, + mode_collect_in_gpu, max_batch_size)); } else { PADDLE_THROW(platform::errors::Unimplemented( "PaddleBox only support AucCalculator, MultiTaskAucCalculator " @@ -721,15 +757,10 @@ class BoxWrapper { std::map metric_lists_; std::vector metric_name_list_; std::vector slot_vector_; - std::vector keys_tensor; // Cache for pull_sparse - std::vector dims_tensor; bool use_afs_api_ = false; std::shared_ptr file_manager_ = nullptr; - - platform::Timer* all_pull_timers_ = nullptr; - platform::Timer* boxps_pull_timers_ = nullptr; - platform::Timer* all_push_timers_ = nullptr; - platform::Timer* boxps_push_timers_ = nullptr; + // box device cache + DeviceBoxData* device_caches_ = nullptr; public: static std::shared_ptr data_shuffle_; diff --git a/paddle/fluid/framework/fleet/box_wrapper_impl.h b/paddle/fluid/framework/fleet/box_wrapper_impl.h index c9008a06baf035..da34926a6a258c 100644 --- a/paddle/fluid/framework/fleet/box_wrapper_impl.h +++ b/paddle/fluid/framework/fleet/box_wrapper_impl.h @@ -28,55 +28,73 @@ void BoxWrapper::PullSparseCase(const 
paddle::platform::Place& place, // VLOG(3) << "Begin PullSparse"; #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) int device_id = BOOST_GET_CONST(platform::CUDAPlace, place).GetDeviceId(); - platform::Timer& all_timer = all_pull_timers_[device_id]; - platform::Timer& pull_boxps_timer = boxps_pull_timers_[device_id]; + DeviceBoxData& dev = device_caches_[device_id]; + platform::Timer& all_timer = dev.all_pull_timer; + platform::Timer& pull_boxps_timer = dev.boxps_pull_timer; #else platform::Timer all_timer; platform::Timer pull_boxps_timer; #endif all_timer.Resume(); - int64_t total_length = - std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); - auto buf = memory::AllocShared( - place, total_length * - sizeof(boxps::FeatureValueGpu)); + // construct slot_level lod info + auto slot_lengths_lod = slot_lengths; + int slot_num = static_cast(slot_lengths.size()); + for (int i = 1; i < slot_num; i++) { + slot_lengths_lod[i] += slot_lengths_lod[i - 1]; + } + int64_t total_length = slot_lengths_lod[slot_num - 1]; + size_t total_bytes = reinterpret_cast( + total_length * + sizeof(boxps::FeatureValueGpu)); +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + dev.total_key_length = total_length; + auto& pull_buf = dev.pull_push_buf; + if (pull_buf == nullptr) { + pull_buf = memory::AllocShared(place, total_bytes); + } else if (total_bytes > pull_buf->size()) { + auto buf = memory::AllocShared(place, total_bytes); + pull_buf.swap(buf); + buf = nullptr; + } +#else + auto pull_buf = memory::AllocShared(place, total_bytes); +#endif boxps::FeatureValueGpu* total_values_gpu = reinterpret_cast*>( - buf->ptr()); + pull_buf->ptr()); if (platform::is_cpu_place(place)) { PADDLE_THROW(platform::errors::Unimplemented( "Warning:: CPUPlace is not supported in PaddleBox now.")); } else if (platform::is_gpu_place(place)) { #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) - // VLOG(3) << "Begin copy keys, key_num[" << total_length << "]"; - // int device_id = BOOST_GET_CONST(platform::CUDAPlace, - // place).GetDeviceId(); - LoDTensor& total_keys_tensor = keys_tensor[device_id]; + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get( + BOOST_GET_CONST(platform::CUDAPlace, place))) + ->stream(); uint64_t* total_keys = reinterpret_cast( - total_keys_tensor.mutable_data({total_length, 1}, place)); + dev.keys_tensor.mutable_data({total_length, 1}, place)); int* total_dims = reinterpret_cast( - dims_tensor[device_id].mutable_data({total_length, 1}, place)); - // construct slot_level lod info - auto slot_lengths_lod = slot_lengths; - for (size_t i = 1; i < slot_lengths_lod.size(); i++) { - slot_lengths_lod[i] += slot_lengths_lod[i - 1]; + dev.dims_tensor.mutable_data({total_length, 1}, place)); + if (dev.gpu_keys_ptr == nullptr) { + dev.gpu_keys_ptr = + memory::AllocShared(place, keys.size() * sizeof(uint64_t*)); } - auto buf_key = memory::AllocShared(place, keys.size() * sizeof(uint64_t*)); - auto buf_length = - memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t)); - uint64_t** gpu_keys = reinterpret_cast(buf_key->ptr()); - int64_t* gpu_len = reinterpret_cast(buf_length->ptr()); - cudaMemcpy(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*), - cudaMemcpyHostToDevice); - cudaMemcpy(gpu_len, slot_lengths_lod.data(), - slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice); - - this->CopyKeys(place, gpu_keys, total_keys, gpu_len, - static_cast(slot_lengths.size()), - static_cast(total_length)); - // VLOG(3) << "Begin call PullSparseGPU in BoxPS"; + + int* key2slot = 
reinterpret_cast( + dev.keys2slot.mutable_data({total_length, 1}, place)); + uint64_t** gpu_keys = reinterpret_cast(dev.gpu_keys_ptr->ptr()); + int64_t* slot_lens = reinterpret_cast( + dev.slot_lens.mutable_data({slot_num, 1}, place)); + cudaMemcpyAsync(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(slot_lens, slot_lengths_lod.data(), + slot_lengths.size() * sizeof(int64_t), + cudaMemcpyHostToDevice, stream); + this->CopyKeys(place, gpu_keys, total_keys, slot_lens, slot_num, + static_cast(total_length), key2slot); + pull_boxps_timer.Resume(); int ret = boxps_ptr_->PullSparseGPU( total_keys, reinterpret_cast(total_values_gpu), @@ -85,12 +103,17 @@ void BoxWrapper::PullSparseCase(const paddle::platform::Place& place, "PullSparseGPU failed in BoxPS.")); pull_boxps_timer.Pause(); - // VLOG(3) << "Begin Copy result to tensor, total_length[" << - // total_length << "]"; - this->CopyForPull(place, gpu_keys, values, - reinterpret_cast(total_values_gpu), gpu_len, - static_cast(slot_lengths.size()), hidden_size, - expand_embed_dim, total_length, total_dims); + if (dev.gpu_values_ptr == nullptr) { + dev.gpu_values_ptr = + memory::AllocShared(place, values.size() * sizeof(float*)); + } + float** gpu_values = reinterpret_cast(dev.gpu_values_ptr->ptr()); + cudaMemcpyAsync(gpu_values, values.data(), values.size() * sizeof(float*), + cudaMemcpyHostToDevice, stream); + + this->CopyForPull(place, gpu_keys, gpu_values, total_values_gpu, slot_lens, + slot_num, key2slot, hidden_size, expand_embed_dim, + total_length, total_dims); #else PADDLE_THROW(platform::errors::PreconditionNotMet( "Please compile WITH_GPU option, because NCCL doesn't support " @@ -101,10 +124,6 @@ void BoxWrapper::PullSparseCase(const paddle::platform::Place& place, "PaddleBox: PullSparse Only Support CPUPlace or CUDAPlace Now.")); } all_timer.Pause(); - // VLOG(1) << "PullSparse total costs: " << all_timer.ElapsedSec() - // << " s, of which BoxPS costs: " << pull_boxps_timer.ElapsedSec() - // << " s"; - // VLOG(3) << "End PullSparse"; } template @@ -114,45 +133,71 @@ void BoxWrapper::PushSparseGradCase( const std::vector& grad_values, const std::vector& slot_lengths, const int hidden_size, const int expand_embed_dim, const int batch_size) { -// VLOG(3) << "Begin PushSparseGrad"; #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) int device_id = BOOST_GET_CONST(platform::CUDAPlace, place).GetDeviceId(); - platform::Timer& all_timer = all_push_timers_[device_id]; - platform::Timer& push_boxps_timer = boxps_push_timers_[device_id]; + DeviceBoxData& dev = device_caches_[device_id]; + platform::Timer& all_timer = dev.all_push_timer; + platform::Timer& push_boxps_timer = dev.boxps_push_timer; #else platform::Timer all_timer; platform::Timer push_boxps_timer; #endif all_timer.Resume(); - int64_t total_length = - std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); - auto buf = memory::AllocShared( - place, +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + int64_t total_length = dev.total_key_length; + // std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); + size_t total_bytes = reinterpret_cast( total_length * - sizeof(boxps::FeaturePushValueGpu)); + sizeof(boxps::FeaturePushValueGpu)); + auto& push_buf = dev.pull_push_buf; + if (push_buf == nullptr) { + push_buf = memory::AllocShared(place, total_bytes); + } else if (total_bytes > push_buf->size()) { + auto buf = memory::AllocShared(place, total_bytes); + push_buf.swap(buf); + buf = nullptr; + } +#else + 
auto push_buf = memory::AllocShared(place, total_bytes); +#endif boxps::FeaturePushValueGpu* total_grad_values_gpu = reinterpret_cast< boxps::FeaturePushValueGpu*>( - buf->ptr()); + push_buf->ptr()); if (platform::is_cpu_place(place)) { PADDLE_THROW(platform::errors::Unimplemented( "Warning:: CPUPlace is not supported in PaddleBox now.")); } else if (platform::is_gpu_place(place)) { #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) - // int device_id = BOOST_GET_CONST(platform::CUDAPlace, - // place).GetDeviceId(); - LoDTensor& cached_total_keys_tensor = keys_tensor[device_id]; + auto stream = dynamic_cast( + platform::DeviceContextPool::Instance().Get( + BOOST_GET_CONST(platform::CUDAPlace, place))) + ->stream(); uint64_t* total_keys = - reinterpret_cast(cached_total_keys_tensor.data()); - int* total_dims = - reinterpret_cast(dims_tensor[device_id].data()); - // VLOG(3) << "Begin copy grad tensor to boxps struct"; - this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths, - hidden_size, expand_embed_dim, total_length, batch_size, - total_dims); - - // VLOG(3) << "Begin call PushSparseGPU in BoxPS"; + reinterpret_cast(dev.keys_tensor.data()); + int* total_dims = reinterpret_cast(dev.dims_tensor.data()); + int slot_num = static_cast(slot_lengths.size()); + if (!dev.d_slot_vector.IsInitialized()) { + int* buf_slot_vector = reinterpret_cast( + dev.d_slot_vector.mutable_data({slot_num, 1}, place)); + cudaMemcpyAsync(buf_slot_vector, slot_vector_.data(), + slot_num * sizeof(int), cudaMemcpyHostToDevice, stream); + } + + const int64_t* slot_lens = + reinterpret_cast(dev.slot_lens.data()); + const int* d_slot_vector = dev.d_slot_vector.data(); + const int* key2slot = reinterpret_cast(dev.keys2slot.data()); + float** gpu_values = reinterpret_cast(dev.gpu_values_ptr->ptr()); + cudaMemcpyAsync(gpu_values, grad_values.data(), + grad_values.size() * sizeof(float*), cudaMemcpyHostToDevice, + stream); + + this->CopyForPush(place, gpu_values, total_grad_values_gpu, d_slot_vector, + slot_lens, slot_num, hidden_size, expand_embed_dim, + total_length, batch_size, total_dims, key2slot); + push_boxps_timer.Resume(); int ret = boxps_ptr_->PushSparseGPU( total_keys, reinterpret_cast(total_grad_values_gpu), @@ -171,10 +216,6 @@ void BoxWrapper::PushSparseGradCase( "PaddleBox: PushSparseGrad Only Support CPUPlace or CUDAPlace Now.")); } all_timer.Pause(); - // VLOG(1) << "PushSparseGrad total cost: " << all_timer.ElapsedSec() - // << " s, of which BoxPS cost: " << push_boxps_timer.ElapsedSec() - // << " s"; - // VLOG(3) << "End PushSparseGrad"; } } // namespace framework diff --git a/paddle/fluid/operators/pull_box_sparse_op.h b/paddle/fluid/operators/pull_box_sparse_op.h index b231ed65833a02..0e48c170d82b24 100644 --- a/paddle/fluid/operators/pull_box_sparse_op.h +++ b/paddle/fluid/operators/pull_box_sparse_op.h @@ -15,33 +15,32 @@ #pragma once #include #include +#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/fleet/box_wrapper.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/framework/eigen.h" namespace paddle { namespace operators { template -static void PaddingZeros(const framework::ExecutionContext &ctx, - framework::LoDTensor* data, - int batch_size, - int hidden_size){ - // set data - data->Resize({1, hidden_size}); - data->mutable_data(ctx.GetPlace()); - auto data_eigen = framework::EigenVector::Flatten(*data); - auto &place = *ctx.template device_context() - .eigen_device(); - 
data_eigen.device(place) = data_eigen.constant(static_cast(0)); - - // set lod - std::vector v_lod(batch_size + 1, 1); - v_lod[0] = 0; - paddle::framework::LoD data_lod; - data_lod.push_back(v_lod); - data->set_lod(data_lod); +static void PaddingZeros(const framework::ExecutionContext &ctx, + framework::LoDTensor *data, int batch_size, + int hidden_size) { + // set data + data->Resize({1, hidden_size}); + data->mutable_data(ctx.GetPlace()); + auto data_eigen = framework::EigenVector::Flatten(*data); + auto &place = *ctx.template device_context() + .eigen_device(); + data_eigen.device(place) = data_eigen.constant(static_cast(0)); + + // set lod + std::vector v_lod(batch_size + 1, 1); + v_lod[0] = 0; + paddle::framework::LoD data_lod; + data_lod.push_back(v_lod); + data->set_lod(data_lod); } template @@ -49,18 +48,19 @@ static void PullBoxSparseFunctor(const framework::ExecutionContext &ctx) { auto inputs = ctx.MultiInput("Ids"); auto outputs = ctx.MultiOutput("Out"); const auto slot_size = inputs.size(); - std::vector all_keys; + std::vector all_keys(slot_size); // BoxPS only supports float now - std::vector all_values; - std::vector slot_lengths; + std::vector all_values(slot_size); + std::vector slot_lengths(slot_size); auto hidden_size = ctx.Attr("size"); // get batch size int batch_size = -1; - for (size_t i = 0; i < slot_size; i++) { - const auto *slot = inputs[i]; - if (slot->numel() == 0) + for (size_t i = 0; i < slot_size; ++i) { + const auto *slot = inputs[i]; + if (slot->numel() == 0) { continue; + } int cur_batch_size = slot->lod().size() ? slot->lod()[0].size() - 1 : slot->dims()[0]; if (batch_size == -1) { @@ -73,20 +73,20 @@ static void PullBoxSparseFunctor(const framework::ExecutionContext &ctx) { } } - for (size_t i = 0; i < slot_size; i++) { + for (size_t i = 0; i < slot_size; ++i) { const auto *slot = inputs[i]; auto *output = outputs[i]; - if (slot->numel() == 0){ - // only support GPU - PaddingZeros(ctx, output, batch_size, hidden_size); - continue; + if (slot->numel() == 0) { + // only support GPU + PaddingZeros(ctx, output, batch_size, hidden_size); + continue; } output->mutable_data(ctx.GetPlace()); const uint64_t *single_slot_keys = reinterpret_cast(slot->data()); - all_keys.push_back(single_slot_keys); - slot_lengths.push_back(slot->numel()); - all_values.push_back(output->data()); + all_keys[i] = single_slot_keys; + slot_lengths[i] = slot->numel(); + all_values[i] = output->data(); } #ifdef PADDLE_WITH_BOX_PS @@ -102,19 +102,18 @@ static void PushBoxSparseFunctor(const framework::ExecutionContext &ctx) { auto d_output = ctx.MultiInput(framework::GradVarName("Out")); const auto slot_size = inputs.size(); - std::vector all_keys; - std::vector all_grad_values; - std::vector slot_lengths; + std::vector all_keys(slot_size); + std::vector all_grad_values(slot_size); + std::vector slot_lengths(slot_size); int batch_size = -1; for (size_t i = 0; i < slot_size; i++) { const auto *slot = inputs[i]; - if(slot->numel() == 0) - continue; + if (slot->numel() == 0) continue; const uint64_t *single_slot_keys = reinterpret_cast(slot->data()); - all_keys.push_back(single_slot_keys); - slot_lengths.push_back(slot->numel()); + all_keys[i] = single_slot_keys; + slot_lengths[i] = slot->numel(); int cur_batch_size = slot->lod().size() ? 
slot->lod()[0].size() - 1 : slot->dims()[0]; if (batch_size == -1) { @@ -126,7 +125,7 @@ static void PushBoxSparseFunctor(const framework::ExecutionContext &ctx) { "please cheack")); } const float *grad_value = d_output[i]->data(); - all_grad_values.push_back(grad_value); + all_grad_values[i] = grad_value; } #ifdef PADDLE_WITH_BOX_PS auto hidden_size = ctx.Attr("size"); diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index a99599af4c6209..90f12e138cca6c 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -481,3 +481,6 @@ DEFINE_int32(padbox_dataset_shuffle_thread_num, 10, "PadBoxSlotDataset shuffle thread num"); DEFINE_int32(padbox_dataset_merge_thread_num, 10, "PadBoxSlotDataset shuffle thread num"); +DEFINE_int32(padbox_slotpool_thread_num, 1, + "PadBoxSlotDataset slot pool thread num"); +
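For reference, a minimal sketch (stand-in names, not the real BoxWrapper members) of the per-device cache pattern that the DeviceBoxData struct above introduces: the device-local tensors, scratch buffers and timers that used to live in parallel arrays move behind one pointer indexed by device id, allocated once during InitializeGPU-style setup and released in the destructor. Because the cache array comes from new[], the sketch pairs it with delete[].

#include <cstdint>

// Stand-in for the per-device cache: everything device-local sits in one
// struct instead of several parallel arrays keyed by device id.
struct DeviceCacheSketch {
  int64_t total_key_length = 0;  // remembered between the pull and push phases
  // ... tensors, scratch buffers and timers would live here ...
};

class WrapperSketch {
 public:
  explicit WrapperSketch(int gpu_num)
      : caches_(new DeviceCacheSketch[gpu_num]) {}
  ~WrapperSketch() {
    delete[] caches_;  // array form matches new DeviceCacheSketch[gpu_num]
    caches_ = nullptr;
  }
  DeviceCacheSketch& cache(int device_id) { return caches_[device_id]; }

 private:
  DeviceCacheSketch* caches_ = nullptr;
};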
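The metric hunks change the add loops from one locking add_data call per element to a single lock over the calculator's table mutex per batch, with add_unlock_data inside the loop. A simplified sketch of that locking split, assuming the table_mutex()/add_unlock_data interface shown in the patch (the calculator class below is a stand-in, not the real BasicAucCalculator):

#include <cstdint>
#include <mutex>
#include <utility>
#include <vector>

// Stand-in calculator exposing the locked/unlocked split used in the patch.
class CalculatorSketch {
 public:
  std::mutex& table_mutex() { return mu_; }
  void add_unlock_data(float pred, int64_t label) {  // caller holds the lock
    table_.emplace_back(pred, label);
  }
  void add_data(float pred, int64_t label) {  // single-element, locked variant
    std::lock_guard<std::mutex> lock(mu_);
    add_unlock_data(pred, label);
  }

 private:
  std::mutex mu_;
  std::vector<std::pair<float, int64_t>> table_;
};

// One lock per batch instead of one lock per element.
void AddBatch(CalculatorSketch* cal, const float* pred, const int64_t* label,
              int len) {
  std::lock_guard<std::mutex> lock(cal->table_mutex());
  for (int i = 0; i < len; ++i) {
    cal->add_unlock_data(pred[i], label[i]);
  }
}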
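PullSparseCase and PushSparseGradCase now share one pull_push_buf per device and reallocate it only when the requested byte count grows. A grow-only reuse sketch, with a plain byte vector standing in for memory::AllocShared:

#include <cstddef>
#include <memory>
#include <vector>

// Grow-only scratch buffer: keep the old allocation when it is already big
// enough, otherwise replace it with a larger one.
using Buffer = std::vector<char>;

void EnsureCapacity(std::shared_ptr<Buffer>* buf, size_t total_bytes) {
  if (*buf == nullptr) {
    *buf = std::make_shared<Buffer>(total_bytes);
  } else if (total_bytes > (*buf)->size()) {
    auto bigger = std::make_shared<Buffer>(total_bytes);
    buf->swap(bigger);  // old allocation released once nothing references it
  }
  // else: reuse the existing allocation untouched
}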
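The pull path replaces std::accumulate with an in-place prefix sum over slot_lengths (slot_lengths_lod): the last entry is the total key count, and the boundaries give CopyKeys the per-slot offsets it needs when filling key2slot. A host-side sketch of that bookkeeping (the real key2slot is produced on the GPU; the reference version below only illustrates what it encodes):

#include <cstdint>
#include <vector>

// Turn per-slot key counts into cumulative offsets. After the loop,
// lod.back() is the total key count and lod[i] is the end offset of slot i.
std::vector<int64_t> BuildSlotLod(const std::vector<int64_t>& slot_lengths) {
  std::vector<int64_t> lod = slot_lengths;
  for (size_t i = 1; i < lod.size(); ++i) {
    lod[i] += lod[i - 1];
  }
  return lod;
}

// Host-side reference of what key2slot encodes: the slot index of each key.
std::vector<int> BuildKey2SlotReference(const std::vector<int64_t>& lod) {
  std::vector<int> key2slot(lod.empty() ? 0 : lod.back());
  int64_t begin = 0;
  for (size_t s = 0; s < lod.size(); ++s) {
    for (int64_t k = begin; k < lod[s]; ++k) {
      key2slot[k] = static_cast<int>(s);
    }
    begin = lod[s];
  }
  return key2slot;
}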
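Host-to-device staging switches from cudaMemcpy to cudaMemcpyAsync on the stream of the device's CUDADeviceContext, so the copies stay ordered with the kernels that consume them on that same stream. A minimal CUDA sketch of the staging call, assuming a valid device destination and stream (the Paddle context plumbing is omitted):

#include <cuda_runtime.h>
#include <vector>

// Stage a host array of data pointers into device memory, ordered on `stream`
// so later kernels queued on the same stream see the staged pointers.
void StagePointersAsync(const std::vector<const float*>& host_ptrs,
                        const float** device_ptrs, cudaStream_t stream) {
  cudaMemcpyAsync(device_ptrs, host_ptrs.data(),
                  host_ptrs.size() * sizeof(const float*),
                  cudaMemcpyHostToDevice, stream);
}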
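In pull_box_sparse_op.h, a slot with no instances in the batch gets a single zero row of shape {1, hidden_size} plus a LoD of batch_size + 1 offsets (0, 1, 1, ..., 1): instance 0 owns the zero row and the remaining instances map to empty ranges, so downstream ops still see batch_size sequences. A host-side sketch of just the LoD that PaddingZeros builds (hypothetical helper name, no Eigen or device code):

#include <cstddef>
#include <vector>

// Offsets for an empty slot: sequence 0 covers the single zero row,
// every later sequence is empty.
std::vector<size_t> EmptySlotLod(int batch_size) {
  std::vector<size_t> lod(batch_size + 1, 1);
  lod[0] = 0;  // offsets: 0, 1, 1, ..., 1
  return lod;
}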
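PullBoxSparseFunctor and PushBoxSparseFunctor pre-size all_keys / all_values / slot_lengths to slot_size and write by slot index instead of push_back, so entry i always corresponds to slot i and an empty slot simply keeps length 0. A small sketch of that index-aligned gathering (the types are stand-ins):

#include <cstdint>
#include <vector>

struct SlotInputSketch {
  const uint64_t* keys = nullptr;
  int64_t num = 0;
};

// Index-aligned gathering: slot i keeps position i; an empty slot contributes
// {nullptr, 0} rather than being compacted away.
void GatherSlots(const std::vector<SlotInputSketch>& slots,
                 std::vector<const uint64_t*>* all_keys,
                 std::vector<int64_t>* slot_lengths) {
  const size_t slot_size = slots.size();
  all_keys->assign(slot_size, nullptr);
  slot_lengths->assign(slot_size, 0);
  for (size_t i = 0; i < slot_size; ++i) {
    if (slots[i].num == 0) continue;  // stays {nullptr, 0} at position i
    (*all_keys)[i] = slots[i].keys;
    (*slot_lengths)[i] = slots[i].num;
  }
}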
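The new padbox_slotpool_thread_num flag follows the existing DEFINE_int32 pattern in flags.cc. A usage-side sketch, assuming the standard gflags pairing of DEFINE in flags.cc with DECLARE in the consuming file (the helper function below is hypothetical):

#include <gflags/gflags.h>

// Declaration side of the flag defined in paddle/fluid/platform/flags.cc.
DECLARE_int32(padbox_slotpool_thread_num);

// Hypothetical consumer: read the flag value at runtime.
int GetSlotPoolThreadNum() { return FLAGS_padbox_slotpool_thread_num; }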