Paddlebox: optimize memory allocation, box pull push and pack batch (Pa…
qingshui authored Jul 29, 2020
1 parent f2da3b4 commit 321c0a3
Showing 10 changed files with 833 additions and 547 deletions.
157 changes: 112 additions & 45 deletions paddle/fluid/framework/data_feed.cc
@@ -1842,7 +1842,7 @@ bool SlotPaddleBoxDataFeed::Start() {
this->finish_start_ = true;
#if defined(PADDLE_WITH_CUDA) && defined(_LINUX)
CHECK(paddle::platform::is_gpu_place(this->place_));
- pack_ = new MiniBatchGpuPack(this->GetPlace(), used_slots_info_);
+ pack_ = BatchGpuPackMgr().get(this->GetPlace(), used_slots_info_);
#endif
return true;
}
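The change in Start() is the core of this commit: instead of constructing a fresh MiniBatchGpuPack (and all of its host and device buffers) on every start, the feed borrows a cached pack from BatchGpuPackMgr(). The manager itself is declared in data_feed.h, which this page does not show; a minimal sketch of the idea, with everything except MiniBatchGpuPack and its reset() assumed, could look like:

// Sketch only: the real BatchGpuPackMgr() lives in data_feed.h and is not
// shown in this diff; the pool below is an illustrative assumption.
class MiniBatchGpuPackMgr {
 public:
  // Returns the cached pack for this GPU place, creating it on first use.
  MiniBatchGpuPack* get(const paddle::platform::Place& place,
                        const std::vector<UsedSlotInfo>& infos) {
    std::lock_guard<std::mutex> lock(mutex_);
    int dev = boost::get<platform::CUDAPlace>(place).GetDeviceId();
    auto it = packs_.find(dev);
    if (it != packs_.end()) {
      it->second->reset(place);  // rebind stream, zero counters, keep buffers
      return it->second.get();
    }
    auto* pack = new MiniBatchGpuPack(place, infos);
    packs_[dev].reset(pack);
    return pack;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<int, std::unique_ptr<MiniBatchGpuPack>> packs_;
};

// Assumed accessor matching the call site above.
inline MiniBatchGpuPackMgr& BatchGpuPackMgr() {
  static MiniBatchGpuPackMgr mgr;  // process-wide, one pack per device
  return mgr;
}

Reusing the pack is also why this commit adds MiniBatchGpuPack::reset() further down: a recycled pack must rebind its CUDA stream to the current place and clear its counters before packing the next batch.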
@@ -1852,23 +1852,45 @@ int SlotPaddleBoxDataFeed::Next() {
if (offset_index_ >= static_cast<int>(batch_offsets_.size())) {
return 0;
}

auto& batch = batch_offsets_[offset_index_++];
if (enable_pv_merge_ && phase == 1) {
// join phase : output_pv_channel to consume_pv_channel
this->batch_size_ = batch.second;
if (this->batch_size_ != 0) {
+ batch_timer_.Resume();
PutToFeedPvVec(&pv_ins_[batch.first], this->batch_size_);
+ batch_timer_.Pause();
} else {
VLOG(3) << "finish reading, batch size zero, thread_id=" << thread_id_;
}
return this->batch_size_;
} else {
this->batch_size_ = batch.second;
+ batch_timer_.Resume();
PutToFeedSlotVec(&records_[batch.first], this->batch_size_);
+ batch_timer_.Pause();
return this->batch_size_;
}
}
+bool SlotPaddleBoxDataFeed::EnablePvMerge(void) {
+ return (enable_pv_merge_ && GetCurrentPhase() == 1);
+}
+int SlotPaddleBoxDataFeed::GetPackInstance(SlotRecord** ins) {
+ if (offset_index_ >= static_cast<int>(batch_offsets_.size())) {
+ return 0;
+ }
+ auto& batch = batch_offsets_[offset_index_];
+ *ins = &records_[batch.first];
+ return batch.second;
+}
+int SlotPaddleBoxDataFeed::GetPackPvInstance(SlotPvInstance** pv_ins) {
+ if (offset_index_ >= static_cast<int>(batch_offsets_.size())) {
+ return 0;
+ }
+ auto& batch = batch_offsets_[offset_index_];
+ *pv_ins = &pv_ins_[batch.first];
+ return batch.second;
+}
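GetPackInstance and GetPackPvInstance read batch_offsets_[offset_index_] without the ++ that Next() applies, so they peek at the mini-batch the feed is about to serve instead of consuming it. An illustrative caller (the trainer side is not part of this diff):

// Hypothetical usage: inspect the pending batch, then let Next() consume it.
SlotRecord* ins = nullptr;
int ins_num = feed->GetPackInstance(&ins);  // cursor is not advanced
if (ins_num > 0) {
  // e.g. prefetch box embeddings for ins[0 .. ins_num) here
}
int batch_size = feed->Next();  // serves the same batch, cursor advances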
void SlotPaddleBoxDataFeed::AssignFeedVar(const Scope& scope) {
CheckInit();
for (int i = 0; i < use_slot_size_; ++i) {
@@ -1886,8 +1908,10 @@ void SlotPaddleBoxDataFeed::PutToFeedPvVec(const SlotPvInstance* pvs, int num) {
paddle::platform::SetDeviceId(
boost::get<platform::CUDAPlace>(place_).GetDeviceId());
pack_->pack_pvinstance(pvs, num);
- GetRankOffsetGPU();
- BuildSlotBatchGPU();
+ int ins_num = pack_->ins_num();
+ int pv_num = pack_->pv_num();
+ GetRankOffsetGPU(pv_num, ins_num);
+ BuildSlotBatchGPU(ins_num);
#else
int ins_number = 0;
std::vector<SlotRecord> ins_vec;
@@ -1971,7 +1995,7 @@ void SlotPaddleBoxDataFeed::PutToFeedSlotVec(const SlotRecord* ins_vec,
paddle::platform::SetDeviceId(
boost::get<platform::CUDAPlace>(place_).GetDeviceId());
pack_->pack_instance(ins_vec, num);
- BuildSlotBatchGPU();
+ BuildSlotBatchGPU(pack_->ins_num());
#else
for (int j = 0; j < use_slot_size_; ++j) {
auto& feed = feed_vec_[j];
@@ -2056,32 +2080,47 @@ void SlotPaddleBoxDataFeed::PutToFeedSlotVec(const SlotRecord* ins_vec,
// LOG(WARNING) << "[" << name << "]" << ostream.str();
//}

-void SlotPaddleBoxDataFeed::BuildSlotBatchGPU(void) {
+void SlotPaddleBoxDataFeed::BuildSlotBatchGPU(const int ins_num) {
#if defined(PADDLE_WITH_CUDA) && defined(_LINUX)
- int ins_num = pack_->ins_num();
+ fill_timer_.Resume();

int offset_cols_size = (ins_num + 1);
- slot_value_offsets_.resize(use_slot_size_ * offset_cols_size);
- auto gpu_slot_offsets = memory::AllocShared(
- this->GetPlace(), (use_slot_size_ * offset_cols_size) * sizeof(size_t));
+ size_t slot_total_bytes =
+ (use_slot_size_ * offset_cols_size) * sizeof(size_t);
+ if (gpu_slot_offsets_ == nullptr) {
+ gpu_slot_offsets_ = memory::AllocShared(this->GetPlace(), slot_total_bytes);
+ } else if (gpu_slot_offsets_->size() < slot_total_bytes) {
+ auto buf = memory::AllocShared(this->GetPlace(), slot_total_bytes);
+ gpu_slot_offsets_.swap(buf);
+ buf = nullptr;
+ }

auto& value = pack_->value();
const UsedSlotGpuType* used_slot_gpu_types =
static_cast<const UsedSlotGpuType*>(pack_->get_gpu_slots());
- FillSlotValueOffset(&slot_value_offsets_, ins_num, use_slot_size_,
- reinterpret_cast<size_t*>(gpu_slot_offsets->ptr()),
+ FillSlotValueOffset(ins_num, use_slot_size_,
+ reinterpret_cast<size_t*>(gpu_slot_offsets_->ptr()),
value.d_uint64_offset.data(), uint64_use_slot_size_,
value.d_float_offset.data(), float_use_slot_size_,
used_slot_gpu_types);
+ fill_timer_.Pause();
+ size_t* d_slot_offsets = reinterpret_cast<size_t*>(gpu_slot_offsets_->ptr());

- std::vector<size_t> offsets(offset_cols_size, 0);
- std::vector<void*> h_tensor_ptrs(use_slot_size_, nullptr);
+ offset_timer_.Resume();
+ thread_local std::vector<size_t> offsets;
+ offsets.resize(offset_cols_size);
+ thread_local HostBuffer<void*> h_tensor_ptrs;
+ h_tensor_ptrs.resize(use_slot_size_);
for (int j = 0; j < use_slot_size_; ++j) {
auto& feed = feed_vec_[j];
if (feed == nullptr) {
h_tensor_ptrs[j] = nullptr;
continue;
}
- memcpy(offsets.data(), &slot_value_offsets_[j * offset_cols_size],
- offset_cols_size * sizeof(size_t));
+ cudaMemcpy(offsets.data(), &d_slot_offsets[j * offset_cols_size],
+ offset_cols_size * sizeof(size_t), cudaMemcpyDeviceToHost);
int total_instance = offsets.back();
CHECK(total_instance >= 0) << "slot idx:" << j
<< ", total instance:" << total_instance;
@@ -2094,9 +2133,7 @@ void SlotPaddleBoxDataFeed::BuildSlotBatchGPU(void) {
h_tensor_ptrs[j] =
feed->mutable_data<int64_t>({total_instance, 1}, this->place_);
}

- LoD data_lod{offsets};
- feed_vec_[j]->set_lod(data_lod);
+ feed->set_lod({offsets});

if (info.dense) {
if (info.inductive_shape_index != -1) {
@@ -2106,24 +2143,27 @@ void SlotPaddleBoxDataFeed::BuildSlotBatchGPU(void) {
feed->Resize(framework::make_ddim(info.local_shape));
}
}
+ offset_timer_.Pause();

- auto buf =
- memory::AllocShared(this->GetPlace(), use_slot_size_ * sizeof(void*));
- void** dest_gpu_p = reinterpret_cast<void**>(buf->ptr());
// fprintf(stderr, "after create tensor\n");
// pack_->copy2tensor_ptr(h_tensor_ptrs);
+ trans_timer_.Resume();
+ if (slot_buf_ptr_ == nullptr) {
+ slot_buf_ptr_ =
+ memory::AllocShared(this->GetPlace(), use_slot_size_ * sizeof(void*));
+ }
+ void** dest_gpu_p = reinterpret_cast<void**>(slot_buf_ptr_->ptr());
cudaMemcpy(dest_gpu_p, h_tensor_ptrs.data(), use_slot_size_ * sizeof(void*),
cudaMemcpyHostToDevice);

CopyForTensor(ins_num, use_slot_size_, dest_gpu_p,
- (const size_t*)gpu_slot_offsets->ptr(),
+ (const size_t*)gpu_slot_offsets_->ptr(),
(const uint64_t*)value.d_uint64_keys.data(),
(const int*)value.d_uint64_offset.data(),
(const int*)value.d_uint64_lens.data(), uint64_use_slot_size_,
(const float*)value.d_float_keys.data(),
(const int*)value.d_float_offset.data(),
(const int*)value.d_float_lens.data(), float_use_slot_size_,
used_slot_gpu_types);
+ trans_timer_.Pause();
#endif
}
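Both gpu_slot_offsets_ and slot_buf_ptr_ (promoted from locals to members in this hunk) follow the same grow-only reuse pattern: allocate on first use, reallocate only when a larger batch needs more bytes, and otherwise leave the allocation alone, so steady-state iterations perform no device allocation at all. Factored out, the idiom is roughly this (helper name illustrative, assuming AllocShared returns a std::shared_ptr<memory::Allocation> with a size() accessor, as used above):

// Grow-only buffer reuse, as applied to gpu_slot_offsets_ and slot_buf_ptr_.
void EnsureDeviceBytes(std::shared_ptr<memory::Allocation>* buf,
                       const platform::Place& place, size_t bytes) {
  if (*buf == nullptr || (*buf)->size() < bytes) {
    *buf = memory::AllocShared(place, bytes);  // old block freed by refcount
  }
}

The thread_local offsets / h_tensor_ptrs vectors above serve the same goal on the host side: each feed thread keeps its scratch buffers alive between batches instead of constructing fresh std::vectors per call.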
int SlotPaddleBoxDataFeed::GetCurrentPhase() {
@@ -2134,16 +2174,15 @@ int SlotPaddleBoxDataFeed::GetCurrentPhase() {
return box_ptr->Phase();
}
}
-void SlotPaddleBoxDataFeed::GetRankOffsetGPU(void) {
+void SlotPaddleBoxDataFeed::GetRankOffsetGPU(const int pv_num,
+ const int ins_num) {
#if defined(PADDLE_WITH_CUDA) && defined(_LINUX)
int max_rank = 3; // the value is a fixed setting
int col = max_rank * 2 + 1;
- int ins_num = pack_->ins_num();
auto& value = pack_->value();
int* tensor_ptr =
rank_offset_->mutable_data<int>({ins_num, col}, this->place_);
- CopyRankOffset(tensor_ptr, ins_num, pack_->pv_num(), max_rank,
+ CopyRankOffset(tensor_ptr, ins_num, pv_num, max_rank,
(const int*)value.d_rank.data(),
(const int*)value.d_cmatch.data(),
(const int*)value.d_ad_offset.data(), col);
@@ -2631,9 +2670,24 @@ bool SlotPaddleBoxDataFeed::ParseOneInstance(const std::string& line,

////////////////////////////// pack ////////////////////////////////////
#if defined(PADDLE_WITH_CUDA) && defined(_LINUX)
-SlotPaddleBoxDataFeed::MiniBatchGpuPack::MiniBatchGpuPack(
- const paddle::platform::Place& place,
- const std::vector<UsedSlotInfo>& infos) {
+static void SetCPUAffinity(int tid) {
+ std::vector<int>& cores = boxps::get_train_cores();
+ if (cores.empty()) {
+ VLOG(0) << "not found binding read ins thread cores";
+ return;
+ }
+
+ size_t core_num = cores.size() / 2;
+ if (core_num < 8) {
+ return;
+ }
+ cpu_set_t mask;
+ CPU_ZERO(&mask);
+ CPU_SET(cores[core_num + (tid % core_num)], &mask);
+ pthread_setaffinity_np(pthread_self(), sizeof(mask), &mask);
+}
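SetCPUAffinity pins a reader thread to one of the cores that boxps reserves for training: the second half of get_train_cores() is given to read/pack threads (cores[core_num + tid % core_num]), and binding is skipped when that half has fewer than 8 cores. The call site is outside this hunk; presumably each reading thread calls it with its own index at startup, along these lines:

// Hypothetical call site, not shown in this diff (read_thread_num assumed).
std::vector<std::thread> readers;
for (int tid = 0; tid < read_thread_num; ++tid) {
  readers.emplace_back([tid]() {
    SetCPUAffinity(tid);  // silently no-ops if boxps exposes no core list
    // ... per-thread parse / pack loop ...
  });
}
for (auto& t : readers) t.join();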
+MiniBatchGpuPack::MiniBatchGpuPack(const paddle::platform::Place& place,
+ const std::vector<UsedSlotInfo>& infos) {
place_ = place;
// paddle::platform::SetDeviceId(boost::get<platform::CUDAPlace>(place).GetDeviceId());
// paddle::platform::CUDADeviceContext* context =
@@ -2661,13 +2715,26 @@ SlotPaddleBoxDataFeed::MiniBatchGpuPack::MiniBatchGpuPack(
++used_float_num_;
}
}
- copy_host2device(&gpu_slots_, gpu_used_slots_);
+ copy_host2device(&gpu_slots_, gpu_used_slots_.data(), gpu_used_slots_.size());
}

-SlotPaddleBoxDataFeed::MiniBatchGpuPack::~MiniBatchGpuPack() {}
+MiniBatchGpuPack::~MiniBatchGpuPack() {}

+void MiniBatchGpuPack::reset(const paddle::platform::Place& place) {
+ place_ = place;
+ stream_ = dynamic_cast<platform::CUDADeviceContext*>(
+ platform::DeviceContextPool::Instance().Get(
+ boost::get<platform::CUDAPlace>(place)))
+ ->stream();
+ ins_num_ = 0;
+ pv_num_ = 0;
+ enable_pv_ = false;
+
+ pack_timer_.Reset();
+ trans_timer_.Reset();
+}

-void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_pvinstance(
- const SlotPvInstance* pv_ins, int num) {
+void MiniBatchGpuPack::pack_pvinstance(const SlotPvInstance* pv_ins, int num) {
pv_num_ = num;
buf_.h_ad_offset.resize(num + 1);
buf_.h_ad_offset[0] = 0;
@@ -2689,8 +2756,7 @@ void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_pvinstance(
pack_instance(&ins_vec_[0], ins_number);
}

-void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_all_data(
- const SlotRecord* ins_vec, int num) {
+void MiniBatchGpuPack::pack_all_data(const SlotRecord* ins_vec, int num) {
int uint64_total_num = 0;
int float_total_num = 0;

@@ -2759,8 +2825,7 @@ void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_all_data(
CHECK(float_total_num == static_cast<int>(buf_.h_float_lens.back()))
<< "float value length error";
}
-void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_uint64_data(
- const SlotRecord* ins_vec, int num) {
+void MiniBatchGpuPack::pack_uint64_data(const SlotRecord* ins_vec, int num) {
int uint64_total_num = 0;

buf_.h_float_lens.clear();
@@ -2809,8 +2874,7 @@ void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_uint64_data(
CHECK(uint64_total_num == static_cast<int>(buf_.h_uint64_lens.back()))
<< "uint64 value length error";
}
-void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_float_data(
- const SlotRecord* ins_vec, int num) {
+void MiniBatchGpuPack::pack_float_data(const SlotRecord* ins_vec, int num) {
int float_total_num = 0;

buf_.h_uint64_lens.clear();
@@ -2859,8 +2923,8 @@ void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_float_data(
<< "float value length error";
}

-void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_instance(
- const SlotRecord* ins_vec, int num) {
+void MiniBatchGpuPack::pack_instance(const SlotRecord* ins_vec, int num) {
+ pack_timer_.Resume();
ins_num_ = num;
CHECK(used_uint64_num_ > 0 || used_float_num_ > 0);
// uint64 and float
@@ -2871,11 +2935,13 @@ void SlotPaddleBoxDataFeed::MiniBatchGpuPack::pack_instance(
} else { // only float
pack_float_data(ins_vec, num);
}
+ pack_timer_.Pause();
// to gpu
transfer_to_gpu();
}

-void SlotPaddleBoxDataFeed::MiniBatchGpuPack::transfer_to_gpu(void) {
+void MiniBatchGpuPack::transfer_to_gpu(void) {
+ trans_timer_.Resume();
if (enable_pv_) {
copy_host2device(&value_.d_ad_offset, buf_.h_ad_offset);
copy_host2device(&value_.d_rank, buf_.h_rank);
@@ -2889,6 +2955,7 @@ void SlotPaddleBoxDataFeed::MiniBatchGpuPack::transfer_to_gpu(void) {
copy_host2device(&value_.d_float_keys, buf_.h_float_keys);
copy_host2device(&value_.d_float_offset, buf_.h_float_offset);
cudaStreamSynchronize(stream_);
+ trans_timer_.Pause();
}
#endif
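transfer_to_gpu() queues one asynchronous host-to-device copy per buffer on stream_ and synchronizes once at the end, and the constructor above now calls a pointer-plus-length overload of copy_host2device. Neither overload is visible in this diff; a plausible shape, assuming the HostBuffer/CudaBuffer helpers from data_feed.h, is:

// Assumed sketch of the helpers used above (the real ones are in data_feed.h).
template <typename T>
void MiniBatchGpuPack::copy_host2device(CudaBuffer<T>* buf, const T* val,
                                        size_t size) {
  if (size == 0) return;
  buf->resize(size);  // grow-only device storage, capacity is kept
  cudaMemcpyAsync(buf->data(), val, size * sizeof(T),
                  cudaMemcpyHostToDevice, stream_);
}

template <typename T>
void MiniBatchGpuPack::copy_host2device(CudaBuffer<T>* buf,
                                        const HostBuffer<T>& val) {
  copy_host2device(buf, val.data(), val.size());
}

Staging every copy on stream_ and paying a single cudaStreamSynchronize avoids one synchronization per buffer.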

30 changes: 13 additions & 17 deletions paddle/fluid/framework/data_feed.cu
@@ -25,8 +25,8 @@ namespace framework {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
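CUDA_KERNEL_LOOP is the standard grid-stride loop: each thread starts at its global index and strides by the total number of launched threads, so one fixed launch shape covers any n. Expanded into a toy kernel for clarity (illustrative only):

// What CUDA_KERNEL_LOOP(i, n) expands to inside the kernels below.
__global__ void ScaleKernel(float* x, float a, size_t n) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    x[i] *= a;  // threads loop when n exceeds the grid size
  }
}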

-__global__ void CopyForTensorKernel(FeatureItem* src, void** dest,
- size_t* offset, char* type,
+__global__ void CopyForTensorKernel(FeatureItem *src, void **dest,
+ size_t *offset, char *type,
size_t total_size, size_t row_size,
size_t col_size) {
CUDA_KERNEL_LOOP(i, row_size * col_size) {
@@ -42,14 +42,14 @@ __global__ void CopyForTensorKernel(FeatureItem *src, void **dest,
right = offset[row_id * (col_size + 1) + col_id + 1] -
offset[(row_id - 1) * (col_size + 1) + col_id + 1];
}
- uint64_t* up = NULL;
- float* fp = NULL;
+ uint64_t *up = NULL;
+ float *fp = NULL;
if (type[row_id] == 'f') {
- fp = reinterpret_cast<float*>(dest[row_id]);
+ fp = reinterpret_cast<float *>(dest[row_id]);
} else {
- up = reinterpret_cast<uint64_t*>(
- *(reinterpret_cast<uint64_t**>(dest) + row_id));
+ up = reinterpret_cast<uint64_t *>(
+ *(reinterpret_cast<uint64_t **>(dest) + row_id));
}
size_t begin = offset[row_id * (col_size + 1) + col_id + 1] +
offset[(row_size - 1) * (col_size + 1) + col_id] -
@@ -72,10 +72,10 @@ __global__ void CopyForTensorKernel(FeatureItem *src, void **dest,
}

void MultiSlotInMemoryDataFeed::CopyForTensor(
- const paddle::platform::Place& place, FeatureItem* src, void** dest,
- size_t* offset, char* type, size_t total_size, size_t row_size,
+ const paddle::platform::Place &place, FeatureItem *src, void **dest,
+ size_t *offset, char *type, size_t total_size, size_t row_size,
size_t col_size) {
- auto stream = dynamic_cast<platform::CUDADeviceContext*>(
+ auto stream = dynamic_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(
boost::get<platform::CUDAPlace>(place)))
->stream();
@@ -91,7 +91,7 @@ const int CUDA_NUM_THREADS = 512;
inline int GET_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}

+ // fill slot values
__global__ void FillSlotValueOffsetKernel(
const int ins_num, const int used_slot_num, size_t *slot_value_offsets,
const int *uint64_offsets, const int uint64_slot_size,
@@ -127,8 +127,7 @@ __global__ void FillSlotValueOffsetKernel(
}

void SlotPaddleBoxDataFeed::FillSlotValueOffset(
- std::vector<size_t> *cpu_slot_value_offsets, const int ins_num,
- const int used_slot_num, size_t *slot_value_offsets,
+ const int ins_num, const int used_slot_num, size_t *slot_value_offsets,
const int *uint64_offsets, const int uint64_slot_size,
const int *float_offsets, const int float_slot_size,
const UsedSlotGpuType *used_slots) {
@@ -140,9 +139,6 @@ void SlotPaddleBoxDataFeed::FillSlotValueOffset(
stream>>>(
ins_num, used_slot_num, slot_value_offsets, uint64_offsets,
uint64_slot_size, float_offsets, float_slot_size, used_slots);
- cudaMemcpyAsync(cpu_slot_value_offsets->data(), slot_value_offsets,
- cpu_slot_value_offsets->size() * sizeof(size_t),
- cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
}
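The deleted cudaMemcpyAsync was a bulk device-to-host copy of every slot's offsets after each fill. With it gone, offsets stay device-resident and BuildSlotBatchGPU (in data_feed.cc above) reads back only the offset_cols_size entries it needs per used slot; cudaStreamSynchronize remains so the kernel's writes are visible to those later blocking reads. The per-slot readback, pulled out as a sketch (names from the diff, the helper itself is illustrative):

// One small D2H copy per used slot replaces the old one large copy per batch.
size_t ReadSlotOffsets(const size_t* d_slot_offsets, size_t* h_offsets,
                       int slot_idx, int offset_cols_size) {
  cudaMemcpy(h_offsets, &d_slot_offsets[slot_idx * offset_cols_size],
             offset_cols_size * sizeof(size_t), cudaMemcpyDeviceToHost);
  return h_offsets[offset_cols_size - 1];  // total values for this slot
}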

