modify fetch logic, use D2H Stream #35191

Merged · 3 commits · Sep 1, 2021
17 changes: 4 additions & 13 deletions paddle/fluid/framework/new_executor/interpretercore.cc
@@ -150,8 +150,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
       main_program_(main_prog),
       global_scope_(global_scope),
       d2h_ctx_pool_({place}),
-      h2d_ctx_pool_({place}),
-      fetch_context_pool_({place}) {
+      h2d_ctx_pool_({place}) {
   is_build_ = false;

   garbages_.reset(new GarbageQueue());
@@ -348,9 +347,6 @@ void InterpreterCore::BuildInstructionCtx(Instruction* instr_node,
       new RuntimeInferShapeContext(*op_base, *instr_node->runtime_ctx_.get()));

   auto* dev_ctx = instr_node->dev_ctx_;
-  if (instr_node->kernel_func_.operator_base_->Type() == "fetch_v2") {
-    dev_ctx = fetch_context_pool_.Get(place);
-  }
   Scope scope;

   instr_node->execution_ctx_.reset(new ExecutionContext(
@@ -362,12 +358,6 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
                    instr_node.kernel_func_.operator_base_)
       ->InferShape(instr_node.infershape_ctx_.get());

-  if (instr_node.kernel_func_.operator_base_->Type() == "fetch_v2") {
-    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-    auto* dev_ctx = pool.Get(place_);
-    dev_ctx->Wait();  // TODO(wanghuancoder)
-  }
-
   instr_node.kernel_func_.compute_func_(*instr_node.execution_ctx_.get());
 }

@@ -412,8 +402,6 @@ void InterpreterCore::ExecuteInstructionList(
                             working_var_ref);
   }

-  fetch_context_pool_.Get(place)->Wait();
-
   for (size_t i = 0; i < working_var_ref.size(); ++i) {
     if (working_var_ref[i].var_ref_count_ != 0) {
       std::cerr << " var ref is not zero " << i << std::endl;
@@ -672,6 +660,9 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place,
                                 expected_kernel_key);
       if (!platform::is_same_place(kernel_type_for_var.place_,
                                    expected_kernel_key.place_)) {
+        if (op_base->Type() == "fetch_v2") {
+          op_base->SetAttr("deepcopy", false);
+        }
         // need trans place
         // 1. add var in scope
         // 2. add copy op
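Taken together, the interpretercore.cc changes drop the fetch-specific device context pool and the blocking Wait() around fetch_v2: when the fetched variable needs a cross-place transfer, the interpreter inserts a copy op (riding the D2H stream) and flags fetch_v2 with deepcopy=false, so the fetch kernel can alias the freshly copied CPU tensor instead of copying it again. A minimal sketch of that decision, using an illustrative OpStub stand-in rather than Paddle's OperatorBase/SetAttr API:

```cpp
#include <string>

// Hedged sketch, not Paddle's actual API: OpStub models just enough of an
// operator to show the branch added in BuildOpFuncList above.
struct OpStub {
  std::string type;
  bool deepcopy = true;  // mirrors the attribute's default in fetch_v2_op.cc
};

void MarkFetchForSharing(OpStub* op, bool need_place_transform) {
  if (need_place_transform && op->type == "fetch_v2") {
    // A D2H copy op will be inserted before fetch_v2, so its input is
    // already a private CPU tensor; aliasing it avoids a second copy.
    op->deepcopy = false;
  }
}
```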
2 changes: 0 additions & 2 deletions paddle/fluid/framework/new_executor/interpretercore.h
@@ -108,8 +108,6 @@ class InterpreterCore {
   size_t max_memory_size_;
   size_t cur_memory_size_;
   std::unique_ptr<WorkQueue> gc_queue_;
-
-  platform::DeviceContextPool fetch_context_pool_;
 };
 }  // namespace framework
 }  // namespace paddle
75 changes: 30 additions & 45 deletions paddle/fluid/operators/controlflow/fetch_v2_op.cc
@@ -36,10 +36,9 @@ struct float16;
 namespace paddle {
 namespace operators {

-static void DataCopy(const framework::LoDTensor &src_item,
+static void DeepCopy(const framework::LoDTensor &src_item,
                      const std::string &fetch_var_name,
-                     framework::LoDTensor *dst_item,
-                     const platform::DeviceContext &dev_ctx) {
+                     framework::LoDTensor *dst_item) {
   if (src_item.IsInitialized() && src_item.numel() > 0) {
 #ifdef PADDLE_WITH_MKLDNN
     // Conversion from MKL-DNN to Paddle
@@ -53,26 +52,13 @@ static void DataCopy(const framework::LoDTensor &src_item,
             : paddle::platform::MKLDNNDeviceContext::tls()
                   .get_cur_paddle_data_layout(),
         src_item, &out, platform::CPUPlace());
-    TensorCopy(src_item, platform::CPUPlace(), dev_ctx, dst_item);
+    TensorCopySync(out, platform::CPUPlace(), dst_item);
   } else {
-    if (platform::is_gpu_place(src_item.place())) {
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-      TensorCopy(src_item, platform::CUDAPinnedPlace(), dev_ctx, dst_item);
-#endif
-    } else {
-      TensorCopy(src_item, platform::CPUPlace(), dst_item);
-    }
+    TensorCopySync(src_item, platform::CPUPlace(), dst_item);
   }
 #else
-  if (platform::is_gpu_place(src_item.place())) {
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-    TensorCopy(src_item, platform::CUDAPinnedPlace(), dev_ctx, dst_item);
-#endif
-  } else {
-    TensorCopy(src_item, platform::CPUPlace(), dst_item);
-  }
+  TensorCopySync(src_item, platform::CPUPlace(), dst_item);
 #endif
-
   } else {
     // Not copy, if the src tensor is empty.
     dst_item->clear();
@@ -92,15 +78,14 @@ class FetchV2Op : public framework::OperatorWithKernel {
       const std::string &var_name, const framework::Tensor &tensor,
       const framework::OpKernelType &expected_kernel_type) const override {
     return framework::OpKernelType(expected_kernel_type.data_type_,
-                                   expected_kernel_type.place_,
-                                   tensor.layout());
+                                   tensor.place(), tensor.layout());
   }

   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+        platform::CPUPlace());
   }
 };

@@ -119,12 +104,10 @@ class FetchV2Kernel {
     if (fetch_var == nullptr) {
       return;
     }
-    PADDLE_ENFORCE_EQ(ctx.HasOutput("Out"), true,
-                      platform::errors::NotFound(
-                          "Output(Out) of memcpy_d2h_op is not found."));
+    PADDLE_ENFORCE_EQ(
+        ctx.HasOutput("Out"), true,
+        platform::errors::NotFound("Output(Out) of fetch_v2_op is not found."));
     auto *out_var = ctx.OutputVar("Out");
-    // Get dev_ctx from ExecutionContext, it's D2H stream
-    auto &dev_ctx = ctx.device_context();

     int col = ctx.Attr<int>("col");
     PADDLE_ENFORCE_GE(
@@ -140,18 +123,34 @@ class FetchV2Kernel {
       fetch_list->resize(col + 1);
     }

+    bool deepcopy = ctx.Attr<bool>("deepcopy");
+
     if (fetch_var->IsType<framework::LoDTensor>()) {
       auto &src_item = fetch_var->Get<framework::LoDTensor>();
       auto *dst_item = &(BOOST_GET(framework::LoDTensor, fetch_list->at(col)));
-      DataCopy(src_item, fetch_var_name, dst_item, dev_ctx);
+      PADDLE_ENFORCE_EQ(platform::is_cpu_place(src_item.place()), true,
+                        platform::errors::InvalidArgument(
+                            "Tensor's place of input(X) must be CPUPlace."));
+      if (deepcopy) {
+        DeepCopy(src_item, fetch_var_name, dst_item);
+      } else {
+        dst_item->ShareDataWith(src_item);
+      }
     } else {
       auto &src_item = fetch_var->Get<framework::LoDTensorArray>();
       framework::LoDTensorArray tmp(src_item.size());
       fetch_list->at(col) = tmp;
       auto &dst_item =
           BOOST_GET(framework::LoDTensorArray, fetch_list->at(col));
       for (size_t i = 0; i < src_item.size(); ++i) {
-        DataCopy(src_item[i], fetch_var_name, &dst_item[i], dev_ctx);
+        PADDLE_ENFORCE_EQ(platform::is_cpu_place(src_item[i].place()), true,
+                          platform::errors::InvalidArgument(
+                              "Tensor's place of input(X) must be CPUPlace."));
+        if (deepcopy) {
+          DeepCopy(src_item[i], fetch_var_name, &dst_item[i]);
+        } else {
+          dst_item[i].ShareDataWith(src_item[i]);
+        }
       }
     }
   }
@@ -167,6 +166,8 @@ class FetchV2OpProtoMaker : public framework::OpProtoAndCheckerMaker {
               "(vector<LoDTensor>) A fetching list of LoDTensor which may have "
               "different dimension, shape and data type.");
     AddAttr<int>("col", "(int) The column index of fetching object.");
+    AddAttr<bool>("deepcopy", "(bool) Whether deep copy is required.")
+        .SetDefault(true);
     AddComment(R"DOC(
 FetchV2 Operator.

@@ -192,19 +193,3 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(fetch_v2, float, ops::FetchV2Kernel, double,
                                int64_t, ops::FetchV2Kernel, bool,
                                ops::FetchV2Kernel, plat::float16,
                                ops::FetchV2Kernel);
-
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_ROCM)
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(fetch_v2, float, ops::FetchV2Kernel, double,
-                                ops::FetchV2Kernel, int, ops::FetchV2Kernel,
-                                int64_t, ops::FetchV2Kernel, bool,
-                                ops::FetchV2Kernel, plat::float16,
-                                ops::FetchV2Kernel);
-#endif
-
-#ifdef PADDLE_WITH_ASCEND_CL
-REGISTER_OP_NPU_KERNEL_FUNCTOR(fetch_v2, float, ops::FetchV2Kernel, double,
-                               ops::FetchV2Kernel, int, ops::FetchV2Kernel,
-                               int64_t, ops::FetchV2Kernel, bool,
-                               ops::FetchV2Kernel, plat::float16,
-                               ops::FetchV2Kernel);
-#endif
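The rewritten kernel assumes its input already lives on the CPU (hence the is_cpu_place checks) and then either deep-copies or aliases it according to the deepcopy attribute. A self-contained sketch of the difference between the two paths, with a plain Buffer type standing in for LoDTensor (illustrative only, not Paddle code):

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Buffer mimics a tensor whose storage is reference-counted, so we can
// contrast a deep copy (like DeepCopy/TensorCopySync) with an aliasing
// share (like ShareDataWith).
struct Buffer {
  std::shared_ptr<std::vector<float>> data;

  void DeepCopyFrom(const Buffer& src) {
    data = std::make_shared<std::vector<float>>(*src.data);  // new storage
  }
  void ShareDataWith(const Buffer& src) {
    data = src.data;  // alias the same storage, no copy
  }
};

int main() {
  Buffer src{std::make_shared<std::vector<float>>(3, 1.0f)};

  Buffer copied, shared;
  copied.DeepCopyFrom(src);   // the deepcopy == true path
  shared.ShareDataWith(src);  // the deepcopy == false path

  (*src.data)[0] = 7.0f;              // a later write to the source...
  assert((*copied.data)[0] == 1.0f);  // ...does not affect the deep copy
  assert((*shared.data)[0] == 7.0f);  // ...but is visible through the alias
  return 0;
}
```

Sharing is only safe here because the interpreter sets deepcopy=false solely when the fetch input is a fresh tensor produced by the inserted D2H copy op; the default remains a deep copy.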