[OneDNN] Add cache in Deconvolution kernel (#60922)
zhanglirong1999 authored Jan 19, 2024
1 parent 6a8c3c5 commit 5e87a34
Showing 1 changed file with 115 additions and 35 deletions.
150 changes: 115 additions & 35 deletions paddle/phi/kernels/onednn/conv_transpose_kernel.cc
@@ -24,6 +24,14 @@

namespace phi {

struct DeconvolutionCache {
  dnnl::deconvolution_forward deconvolution_forward;
  dnnl::memory src_mem;
  dnnl::memory weights_mem;
  dnnl::memory bias_mem;
  dnnl::memory dst_mem;
};
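// Note: dnnl primitives and memories are reference-counted handles, so the
// copies stored here share the underlying oneDNN objects rather than
// duplicating the compiled kernel or its buffers.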

inline dnnl::memory::dims GetWeightsTz(const phi::DenseTensor* filter,
                                       const int groups) {
  auto weights_tz = common::vectorize(filter->dims());
@@ -325,6 +333,25 @@ class ConvTransposeOneDNNHandlerT
  }
};

template <typename T>
void PrepareSrcMem(const std::shared_ptr<dnnl::deconvolution_forward>& fc_p
                       UNUSED,
                   const std::shared_ptr<dnnl::memory>& src_mem,
                   const phi::DenseTensor* x,
                   const dnnl::engine& engine) {
  auto x_md = x->mem_desc().reshape(src_mem->get_desc().get_dims());
  if (x_md != src_mem->get_desc()) {
    dnnl::memory x_mem(x_md, engine, phi::funcs::to_void_cast<T>(x->data<T>()));
    auto reorder_p = dnnl::reorder(x_mem, *src_mem);

    auto& astream = OneDNNContext::tls().get_stream();
    reorder_p.execute(astream, x_mem, *src_mem);
    astream.wait();
  } else {
    src_mem->set_data_handle(phi::funcs::to_void_cast<T>(x->data<T>()));
  }
}
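// Note on PrepareSrcMem(): the reorder branch runs only when the incoming
// tensor's memory descriptor differs from the one the cached primitive was
// built with; in the matching case a cache hit costs just a pointer swap.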

template <typename T, typename T_out>
void Execute(const OneDNNContext& dev_ctx,
             const DenseTensor* x,
@@ -338,41 +365,94 @@ void Execute(const OneDNNContext& dev_ctx,
  const auto* bias =
      dev_ctx.HasDnnInput("Bias") ? dev_ctx.GetDnnInput("Bias") : nullptr;

  ConvTransposeOneDNNHandlerT<T, float, T_out> handler(dev_ctx,
                                                       x,
                                                       filter,
                                                       bias,
                                                       strides,
                                                       paddings,
                                                       padding_algorithm,
                                                       groups,
                                                       dilations,
                                                       out);

  auto src_memory_p = handler.AcquireSrcMemoryWithReorder(x);
  // Caching Key for weights is needed
  std::string key =
      funcs::CreateKey(dev_ctx,
                       dev_ctx.GetInputsName("Input")[0],
                       dev_ctx.GetInputsName("Filter")[0],
                       (bias ? dev_ctx.GetInputsName("Bias")[0] : ""));
  key = funcs::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
  auto weights_memory_p =
      handler.AcquireWeightsMemoryWithReorder(dev_ctx, key, filter, groups);

  std::shared_ptr<dnnl::memory> dst_memory_p =
      handler.template AcquireDstMemory<T_out>(out);
  auto conv_p = handler.AcquireForwardPrimitive();

  std::unordered_map<int, dnnl::memory> args = {
      {DNNL_ARG_SRC, *src_memory_p},
      {DNNL_ARG_WEIGHTS, *weights_memory_p},
      {DNNL_ARG_DST, *dst_memory_p}};

  if (bias) {
    auto bias_memory_p =
        handler.AcquireBiasMemoryWithReorder(dev_ctx, key, bias);
    args.insert({DNNL_ARG_BIAS, *bias_memory_p});
  std::shared_ptr<dnnl::deconvolution_forward> conv_p;
  std::shared_ptr<dnnl::memory> src_memory_p;
  std::shared_ptr<dnnl::memory> weights_memory_p;
  std::shared_ptr<dnnl::memory> bias_memory_p;
  std::shared_ptr<dnnl::memory> dst_memory_p;
  std::unordered_map<int, dnnl::memory> args;

  // The key folds in the concrete input and filter shapes, so a shape change
  // creates a fresh cache entry instead of reusing a stale primitive.
  std::string cache_key = funcs::CreateKey(dev_ctx,
                                           dev_ctx.GetInputsName("Input")[0],
                                           dev_ctx.GetInputsName("Filter")[0],
                                           common::vectorize(x->dims()),
                                           common::vectorize(filter->dims()));
  const auto& onednn_engine = dev_ctx.GetEngine();

  auto deconvolution_cache =
      std::static_pointer_cast<DeconvolutionCache>(dev_ctx.GetBlob(cache_key));
  if (deconvolution_cache) {
    // Cache hit: reuse the stored primitive and memory objects, refreshing
    // only the data handles so they point at this call's buffers.
    conv_p = std::make_shared<dnnl::deconvolution_forward>(
        deconvolution_cache->deconvolution_forward);

    src_memory_p = std::make_shared<dnnl::memory>(deconvolution_cache->src_mem);
    PrepareSrcMem<T>(conv_p, src_memory_p, x, onednn_engine);

    weights_memory_p =
        std::make_shared<dnnl::memory>(deconvolution_cache->weights_mem);
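    // The weights handle is reused without a refresh: the reordered filter
    // buffer is assumed to persist for the lifetime of the cache entry,
    // since it was cached under its own key on the first run.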

    dst_memory_p = std::make_shared<dnnl::memory>(deconvolution_cache->dst_mem);
    auto out_ptr =
        dev_ctx.template Alloc<T_out>(out, dst_memory_p->get_desc().get_size());

    dst_memory_p->set_data_handle(out_ptr);

    args.insert({DNNL_ARG_SRC, *src_memory_p});
    args.insert({DNNL_ARG_WEIGHTS, *weights_memory_p});
    args.insert({DNNL_ARG_DST, *dst_memory_p});

    if (bias) {
      bias_memory_p =
          std::make_shared<dnnl::memory>(deconvolution_cache->bias_mem);
      args.insert({DNNL_ARG_BIAS, *bias_memory_p});
    }
  } else {
    // Cache miss: build the primitive and memory objects through the
    // handler, then store copies in the blob cache for later runs.
    // Caching Key for weights is needed
    std::string key =
        funcs::CreateKey(dev_ctx,
                         dev_ctx.GetInputsName("Input")[0],
                         dev_ctx.GetInputsName("Filter")[0],
                         (bias ? dev_ctx.GetInputsName("Bias")[0] : ""));

    ConvTransposeOneDNNHandlerT<T, float, T_out> handler(dev_ctx,
                                                         x,
                                                         filter,
                                                         bias,
                                                         strides,
                                                         paddings,
                                                         padding_algorithm,
                                                         groups,
                                                         dilations,
                                                         out);

    src_memory_p = handler.AcquireSrcMemoryWithReorder(x);

    key = funcs::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
    weights_memory_p =
        handler.AcquireWeightsMemoryWithReorder(dev_ctx, key, filter, groups);

    dst_memory_p = handler.template AcquireDstMemory<T_out>(out);

    conv_p = handler.AcquireForwardPrimitive();

    args.insert({DNNL_ARG_SRC, *src_memory_p});
    args.insert({DNNL_ARG_WEIGHTS, *weights_memory_p});
    args.insert({DNNL_ARG_DST, *dst_memory_p});

    if (bias) {
      bias_memory_p = handler.AcquireBiasMemoryWithReorder(dev_ctx, key, bias);
      args.insert({DNNL_ARG_BIAS, *bias_memory_p});
    }
    auto cache = std::make_shared<DeconvolutionCache>();
    cache->deconvolution_forward = *conv_p;
    cache->src_mem = *src_memory_p;
    cache->weights_mem = *weights_memory_p;
    cache->dst_mem = *dst_memory_p;
    if (bias) {
      cache->bias_mem = *bias_memory_p;
    }

    dev_ctx.SetBlob(cache_key, cache);
  }
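  // Both branches arrive here with conv_p and args fully populated, so a
  // single execute call serves the cache-hit and cache-miss paths alike.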
  auto& astream = OneDNNContext::tls().get_stream();
  conv_p->execute(astream, args);