Commit
Merge pull request PaddlePaddle#48 from YaoCheng8667/paddlebox-yc
optimize datafeed by CPU async
YaoCheng8667 authored Feb 1, 2024
2 parents 22296f2 + 1a5192c commit be6d0dd
Showing 3 changed files with 342 additions and 22 deletions.
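The diff below adds a CPU-side double-buffered prefetch path behind the new enable_async_datafeed_batch flag: while the trainer consumes the current mini-batch, a std::async task assembles the next batch's slot tensors (and, for pv batches, the rank-offset tensor) into a spare buffer, and the consumer swaps the two buffers once the prefetch future completes. The following is a minimal, self-contained sketch of that producer/consumer pattern only; Buffer and ProduceBatch are illustrative stand-ins, not the actual MiniBatchSlotPvTensorBuffer or PrefechNextBatch API.

// Minimal sketch of the double-buffered async prefetch pattern (illustrative
// names only; not the actual MiniBatchSlotPvTensorBuffer/PrefechNextBatch API).
#include <future>
#include <iostream>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

struct Buffer {
  std::vector<float> data;   // stands in for the prepared slot tensors
  std::future<bool> done;    // completion signal of the async fill
  bool valid = false;        // true once a fill has been scheduled
};

// Stands in for the prefetch routine: assemble one batch on a background thread.
bool ProduceBatch(Buffer* buf, int batch_id) {
  buf->data.assign(8, static_cast<float>(batch_id));
  return true;
}

int main() {
  auto current = std::make_shared<Buffer>();
  auto next = std::make_shared<Buffer>();
  const int num_batches = 4;

  for (int batch_id = 0; batch_id < num_batches; ++batch_id) {
    if (!next->valid) {  // first batch: nothing prefetched yet, schedule it now
      next->done = std::async(std::launch::async, ProduceBatch, next.get(), batch_id);
      next->valid = true;
    }
    next->done.wait();         // wait for the prefetch of this batch
    std::swap(current, next);  // the filled buffer becomes the current one
    next->valid = false;

    if (batch_id + 1 < num_batches) {  // start prefetching the following batch
      next->done = std::async(std::launch::async, ProduceBatch, next.get(), batch_id + 1);
      next->valid = true;
    }

    // Consume the current batch (stands in for sharing tensors with feed_vec_).
    float sum = std::accumulate(current->data.begin(), current->data.end(), 0.0f);
    std::cout << "batch " << batch_id << " sum=" << sum << std::endl;
  }
  return 0;
}

In the actual change, the role of ProduceBatch is played by PrefechNextBatch and PrefechNextBatchWithPv, and the wait-swap-share step happens in PutToFeedSlotVec and PutToFeedPvVec.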
243 changes: 226 additions & 17 deletions paddle/fluid/framework/data_feed.cc
@@ -36,6 +36,7 @@ limitations under the License. */

USE_INT_STAT(STAT_total_feasign_num_in_mem);
DECLARE_bool(enable_ins_parser_file);
DECLARE_bool(enable_async_datafeed_batch);
#ifdef PADDLE_WITH_BOX_PS
#include <dlfcn.h>
extern "C" {
@@ -3163,6 +3164,10 @@ bool SlotPaddleBoxDataFeed::Start() {
#elif defined(PADDLE_WITH_XPU_KP) && !defined(CPU_DATA_FEED)
pack_ = BatchGpuPackMgr().get(this->GetPlace(), used_slots_info_);
#endif
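// Double buffering for the async CPU datafeed: slot_pv_tensor_buf_ holds the
// batch currently being fed while slot_pv_tensor_buf_next_ is filled by a
// background std::async task; the two are swapped once the prefetch finishes.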
if (FLAGS_enable_async_datafeed_batch) {
slot_pv_tensor_buf_ = std::make_shared<MiniBatchSlotPvTensorBuffer>(used_slots_info_, this->GetPlace());
slot_pv_tensor_buf_next_ = std::make_shared<MiniBatchSlotPvTensorBuffer>(used_slots_info_, this->GetPlace());
}
return true;
}

@@ -3190,6 +3195,17 @@ int SlotPaddleBoxDataFeed::Next() {
} else {
VLOG(3) << "finish reading, batch size zero, thread_id=" << thread_id_;
}

if (FLAGS_enable_async_datafeed_batch && offset_index_ < static_cast<int>(batch_offsets_.size())) {
auto & new_batch = batch_offsets_[offset_index_];
if (new_batch.second != 0) {
std::future<bool> prefetch_done = std::async(std::launch::async,
std::bind(&SlotPaddleBoxDataFeed::PrefechNextBatchWithPv,
this, &pv_ins_[new_batch.first], new_batch.second));
slot_pv_tensor_buf_next_->set_buffer_done(std::move(prefetch_done));
}
}

#ifdef PADDLE_WITH_XPU_KP
CHECK(prepare_next_batch_rt.get() == 0);
#endif
@@ -3200,6 +3216,15 @@ int SlotPaddleBoxDataFeed::Next() {
this->batch_size_ = batch.second;
batch_timer_.Resume();
PutToFeedSlotVec(&records_[batch.first], this->batch_size_);
// prefetch next batch
if (FLAGS_enable_async_datafeed_batch && offset_index_ < static_cast<int>(batch_offsets_.size())) {
auto & new_batch = batch_offsets_[offset_index_];
//PrefechNextBatch(&records_[new_batch.first], new_batch.second);
std::future<bool> prefetch_done = std::async(std::launch::async,
std::bind(&SlotPaddleBoxDataFeed::PrefechNextBatch, this, &records_[new_batch.first], new_batch.second));
slot_pv_tensor_buf_next_->set_buffer_done(std::move(prefetch_done));
}

#if defined(PADDLE_WITH_CUDA) && defined(_LINUX)
// update set join q value
if (FLAGS_padbox_slotrecord_extend_dim > 0) {
@@ -3253,11 +3278,33 @@ void SlotPaddleBoxDataFeed::PutToFeedPvVec(const SlotPvInstance* pvs, int num) {
BuildSlotBatchGPU(ins_num);
#elif defined(PADDLE_WITH_XPU_KP) && !defined(CPU_DATA_FEED)
paddle::platform::SetXPUDeviceId(place_.GetDeviceId());
pack_->pack_pvinstance(pvs, num);
int ins_num = pack_->ins_num();
int pv_num = pack_->pv_num();
GetRankOffsetGPU(pv_num, ins_num);
BuildSlotBatchGPU(ins_num);
if (FLAGS_enable_async_datafeed_batch) {
if (!slot_pv_tensor_buf_next_->valid()) { // first_batch
std::future<bool> prefetch_done = std::async(std::launch::async,
std::bind(&SlotPaddleBoxDataFeed::PrefechNextBatchWithPv, this, pvs, num));
slot_pv_tensor_buf_next_->set_buffer_done(std::move(prefetch_done));
}

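// Wait for the in-flight prefetch, then promote the ready buffer to current.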
slot_pv_tensor_buf_next_->wait_buffer_done();
std::swap(slot_pv_tensor_buf_next_, slot_pv_tensor_buf_);
auto & bufferd_feed_vec = slot_pv_tensor_buf_->feed_vec();

for (int j = 0; j < use_slot_size_; ++j) {
feed_vec_[j]->ShareDataWith(bufferd_feed_vec[j]);
feed_vec_[j]->set_lod(*(bufferd_feed_vec[j].mutable_lod()));
}

LoDTensor & pv_buffered_tensor = slot_pv_tensor_buf_->pv_tensor();
rank_offset_->ShareDataWith(pv_buffered_tensor);
rank_offset_->set_lod(*(pv_buffered_tensor.mutable_lod()));

} else {
pack_->pack_pvinstance(pvs, num);
int ins_num = pack_->ins_num();
int pv_num = pack_->pv_num();
GetRankOffsetGPU(pv_num, ins_num);
BuildSlotBatchGPU(ins_num);
}
#else
int ins_number = 0;
pv_ins_vec_.clear();
@@ -3335,17 +3382,171 @@ void SlotPaddleBoxDataFeed::ExpandSlotRecord(SlotRecord* rec) {
CHECK(float_total_dims_size_ == static_cast<size_t>(offset));
}

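// Prefetches the next pv (page-view) mini-batch: flattens the pv ads into a
// SlotRecord list, computes the rank-offset matrix on the CPU, copies it into
// the buffered pv tensor, and then prefetches the slot tensors themselves.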
bool SlotPaddleBoxDataFeed::PrefechNextBatchWithPv(const SlotPvInstance* pvs, int num) {
int ins_number = 0;
std::vector<SlotRecord> & pv_ins_vec = slot_pv_tensor_buf_next_->pv_ins_vec();
pv_ins_vec.clear();
for (int i = 0; i < num; ++i) {
auto& pv = pvs[i];
ins_number += pv->ads.size();
for (auto ins : pv->ads) {
pv_ins_vec.push_back(ins);
}
}
int max_rank = 3; // fixed maximum rank setting
int row = ins_number;
int col = max_rank * 2 + 1;

std::vector<int> rank_offset_mat(row * col, -1);
rank_offset_mat.shrink_to_fit();

CalRankOffsetCPU(pvs, num, ins_number, rank_offset_mat, max_rank, row, col);
int * rank_offset = rank_offset_mat.data();
slot_pv_tensor_buf_next_->resize_pv_tensor(row, col);
int * tensor_ptr = slot_pv_tensor_buf_next_->pv_tensor().data<int>();
CopyToFeedTensor(tensor_ptr, rank_offset, row * col * sizeof(int));

return PrefechNextBatch(&pv_ins_vec[0], ins_number);
}

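// Prefetches the next mini-batch of SlotRecords: gathers per-slot float/uint64
// feasigns and their LoD offsets, resizes the buffered tensors, and copies the
// data in; the consumer later swaps this buffer in and shares its tensors.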
bool SlotPaddleBoxDataFeed::PrefechNextBatch(const SlotRecord* ins_vec, int num) {
int uint64_total_len = 0, float_total_len = 0;

auto & slot_float_feas = slot_pv_tensor_buf_next_->batch_float_feasigns();
auto & slot_uint64_feas = slot_pv_tensor_buf_next_->batch_uint64_feasigns();
auto & slot_offsets = slot_pv_tensor_buf_next_->offsets();

for (int j = 0; j < use_slot_size_; ++j) {
auto& slot_offset = slot_offsets[j];
slot_offset.clear();
slot_offset.reserve(num + 1);
slot_offset.emplace_back(0);

int slot_total_instance = 0;
auto & info = used_slots_info_[j];
// fill slot value with default value 0
if (info.type[0] == 'f') { // float
auto& batch_fea = slot_float_feas[j];
batch_fea.clear();

for (int i = 0; i < num; ++i) {
auto & r = ins_vec[i];
size_t fea_num = 0;
float* slot_values = r->slot_float_feasigns_.get_values(info.slot_value_idx, &fea_num);
if (fea_num > 0) {
float_total_len += fea_num;
batch_fea.resize(slot_total_instance + fea_num);
memcpy(&batch_fea[slot_total_instance], slot_values,
sizeof(float) * fea_num);
slot_total_instance += fea_num;
}
slot_offset.push_back(slot_total_instance);
}

} else if (info.type[0] == 'u') { // uint64
auto& batch_fea = slot_uint64_feas[j];
batch_fea.clear();

for (int i = 0; i < num; ++i) {
auto & r = ins_vec[i];
size_t fea_num = 0;
uint64_t* slot_values = r->slot_uint64_feasigns_.get_values(info.slot_value_idx, &fea_num);
if (fea_num > 0) {
batch_fea.resize(slot_total_instance + fea_num);
uint64_total_len += fea_num;
memcpy(&batch_fea[slot_total_instance], slot_values,
sizeof(uint64_t) * fea_num);
slot_total_instance += fea_num;
}
slot_offset.push_back(slot_total_instance);
}
}
}

// alloc mem
slot_pv_tensor_buf_next_->resize_tensor(float_total_len, uint64_total_len);

// Copy to feed tensor & add lod
LoDTensor & float_tensor = slot_pv_tensor_buf_next_->float_tensor();
LoDTensor & uint64_tensor = slot_pv_tensor_buf_next_->uint64_tensor();
std::vector<LoDTensor> & feed_vec = slot_pv_tensor_buf_next_->feed_vec();
feed_vec.clear();
feed_vec.resize(use_slot_size_);

size_t slot_uint64_offset = 0;
size_t slot_float_offset = 0;

// shared buffer
for (int j = 0; j < use_slot_size_; ++j) {
size_t slot_total_len = slot_offsets[j][slot_offsets[j].size() - 1];

int feedvec_len = (slot_total_len == 0 ? 1 : slot_total_len);

auto& info = used_slots_info_[j];
if (info.type[0] == 'f') {
feed_vec[j].ShareDataWith(float_tensor.Slice(static_cast<int64_t>(slot_float_offset),
static_cast<int64_t>(slot_float_offset + feedvec_len)));
slot_float_offset += feedvec_len;
if (slot_total_len > 0) {
CopyToFeedTensor(feed_vec[j].data<float>(), &slot_float_feas[j][0], slot_total_len * sizeof(float));
}
} else {
feed_vec[j].ShareDataWith(uint64_tensor.Slice(static_cast<int64_t>(slot_uint64_offset),
static_cast<int64_t>(slot_uint64_offset + feedvec_len)));
slot_uint64_offset += feedvec_len;
if (slot_total_len > 0) {
CopyToFeedTensor(feed_vec[j].data<int64_t>(), &slot_uint64_feas[j][0], slot_total_len * sizeof(int64_t));
}
}
feed_vec[j].Resize({static_cast<long int>(slot_total_len), 1});
if (info.dense) {
if (info.inductive_shape_index != -1) {
info.local_shape[info.inductive_shape_index] = slot_total_len / info.total_dims_without_inductive;
}
feed_vec[j].Resize(phi::make_ddim(info.local_shape));
} else {
LoD& lod = *(feed_vec[j].mutable_lod());
lod.resize(1);
lod[0].resize(num + 1);
paddle::framework::MixVector<size_t> mixv_lod(&lod[0]);
memcpy(mixv_lod.MutableData(platform::CPUPlace()),
&slot_offsets[j][0],
(num + 1) * sizeof(size_t));
}
}
return true;
}

void SlotPaddleBoxDataFeed::PutToFeedSlotVec(const SlotRecord* ins_vec,
int num) {
#if defined(PADDLE_WITH_CUDA) && defined(_LINUX)
paddle::platform::SetDeviceId(place_.GetDeviceId());
pack_->pack_instance(ins_vec, num);
BuildSlotBatchGPU(pack_->ins_num());

#elif defined(PADDLE_WITH_XPU_KP) && !defined(CPU_DATA_FEED)

paddle::platform::SetXPUDeviceId(place_.GetDeviceId());
pack_->pack_instance(ins_vec, num);
BuildSlotBatchGPU(pack_->ins_num());
#else

if (FLAGS_enable_async_datafeed_batch) {
if (!slot_pv_tensor_buf_next_->valid()) { // first_batch
std::future<bool> prefetch_done = std::async(std::launch::async,
std::bind(&SlotPaddleBoxDataFeed::PrefechNextBatch, this, ins_vec, num));
slot_pv_tensor_buf_next_->set_buffer_done(std::move(prefetch_done));
}
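// Wait for the in-flight prefetch, then promote the ready buffer to current.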
slot_pv_tensor_buf_next_->wait_buffer_done();
std::swap(slot_pv_tensor_buf_next_, slot_pv_tensor_buf_);
auto & bufferd_feed_vec = slot_pv_tensor_buf_->feed_vec();
for (int j = 0; j < use_slot_size_; ++j) {
feed_vec_[j]->ShareDataWith(bufferd_feed_vec[j]);
feed_vec_[j]->set_lod(*(bufferd_feed_vec[j].mutable_lod()));
}
} else {
pack_->pack_instance(ins_vec, num);
BuildSlotBatchGPU(pack_->ins_num());
}

#else // by cpu
batch_ins_num_ = num;
ins_record_ptr_ = ins_vec;
for (int j = 0; j < use_slot_size_; ++j) {
@@ -3625,16 +3826,11 @@ void SlotPaddleBoxDataFeed::GetRankOffsetGPU(const int pv_num,
#endif
#endif
}
void SlotPaddleBoxDataFeed::GetRankOffset(const SlotPvInstance* pv_vec,
int pv_num, int ins_number) {
int index = 0;
int max_rank = 3; // fixed maximum rank setting
int row = ins_number;
int col = max_rank * 2 + 1;

std::vector<int> rank_offset_mat(row * col, -1);
rank_offset_mat.shrink_to_fit();

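// Fills rank_offset_mat (row x col, pre-initialized to -1) with the ad rank
// offsets of each pv instance; factored out of GetRankOffset so the async CPU
// prefetch path can reuse it.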
void SlotPaddleBoxDataFeed::CalRankOffsetCPU(const SlotPvInstance* pv_vec,
int pv_num, int ins_number, std::vector<int> & rank_offset_mat,
int max_rank, int row, int col) {
int index = 0;
for (int i = 0; i < pv_num; i++) {
auto pv_ins = pv_vec[i];
int ad_num = pv_ins->ads.size();
@@ -3668,6 +3864,19 @@ void SlotPaddleBoxDataFeed::GetRankOffset(const SlotPvInstance* pv_vec,
index += 1;
}
}
}

void SlotPaddleBoxDataFeed::GetRankOffset(const SlotPvInstance* pv_vec,
int pv_num, int ins_number) {

int max_rank = 3; // fixed maximum rank setting
int row = ins_number;
int col = max_rank * 2 + 1;

std::vector<int> rank_offset_mat(row * col, -1);
rank_offset_mat.shrink_to_fit();

CalRankOffsetCPU(pv_vec, pv_num, ins_number, rank_offset_mat, max_rank, row, col);

int* rank_offset = rank_offset_mat.data();
int* tensor_ptr = rank_offset_->mutable_data<int>({row, col}, this->place_);