From 187fb27f5864487bc8edc41b51d7d6e5332efd01 Mon Sep 17 00:00:00 2001 From: zerorains Date: Thu, 26 Oct 2023 10:57:24 +0000 Subject: [PATCH] add the set_dtype --- paddle/phi/infermeta/fusion.cc | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index c1d943893741b6..effa82f19e4ea5 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -1915,6 +1915,7 @@ void FusedEmbeddingEltWiseLayerNormInferMeta( auto dim_output = phi::make_ddim({batch, seq_len, hidden}); out->set_dims(dim_output); out->share_lod(*ids[0]); + out->set_dtype((*embs[0]).dtype()); } void FusionTransposeFlattenConcatInferMeta( @@ -1976,6 +1977,7 @@ void FusionTransposeFlattenConcatInferMeta( out_dims[concat_axis] = -1; } out->set_dims(phi::make_ddim(out_dims)); + out->set_dtype((*x[0]).dtype()); } void FusedFCElementwiseLayerNormInferMeta(const MetaTensor& x, @@ -2157,11 +2159,14 @@ void FusedFCElementwiseLayerNormInferMeta(const MetaTensor& x, } out->set_dims(y_dims); + out->set_dtype(x.dtype()); if (mean) { + mean->set_dtype(x.dtype()); mean->set_dims({dim_0}); } if (variance) { variance->set_dims({dim_0}); + variance->set_dtype(x.dtype()); } out->share_lod(x); } @@ -2265,6 +2270,7 @@ void FusionGRUInferMeta(const MetaTensor& x, "receiced the width of H0 is:%d, frame size is:%d", h0_dims[1], frame_size)); + reordered_h0->set_dtype(x.dtype()); } if (bias) { auto b_dims = bias.dims(); @@ -2293,6 +2299,7 @@ void FusionGRUInferMeta(const MetaTensor& x, DDim out_dims({x_mat_dims[0], frame_size}); hidden->set_dims(out_dims); hidden->share_lod(x); + hidden->set_dtype(x.dtype()); int xx_width = 0; if (use_seq) { xx_width = static_cast(wx_dims[1]); @@ -2300,9 +2307,12 @@ void FusionGRUInferMeta(const MetaTensor& x, xx_width = static_cast(x_mat_dims[1] > wx_dims[1] ? wx_dims[1] : x_mat_dims[1]); batched_input->set_dims({x_mat_dims[0], wx_dims[1]}); + batched_input->set_dtype(x.dtype()); batched_out->set_dims(out_dims); + batched_out->set_dtype(x.dtype()); } xx->set_dims({x_mat_dims[0], xx_width}); + xx->set_dtype(x.dtype()); xx->share_lod(x); } @@ -2365,6 +2375,8 @@ void FusionSeqConvEltAddReluInferMeta(const MetaTensor& x, out->set_dims({x_dims[0], w_dims[1]}); col_mat->set_dims({x_dims[0], w_dims[0]}); out->share_lod(x); + col_mat->set_dtype(x.dtype()); + out->set_dtype(x.dtype()); } void FusionSeqExpandConcatFCInferMeta(const std::vector& x, @@ -2440,10 +2452,11 @@ void FusionSeqExpandConcatFCInferMeta(const std::vector& x, b_dims[1])); } } + fc_out->set_dtype((*x[0]).dtype()); out->set_dims({ins_dims[0][0], D}); + out->set_dtype((*x[0]).dtype()); // fcout should be reshape when run since can not get lod in infershape // explicit share the ref lod - // ctx->ShareLoD("X", "Out", 0); out->share_lod(*x[0]); } } // namespace phi