Skip to content

Commit

Permalink
add the set_dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
zeroRains committed Oct 26, 2023
1 parent 25a3817 commit 187fb27
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion paddle/phi/infermeta/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1915,6 +1915,7 @@ void FusedEmbeddingEltWiseLayerNormInferMeta(
auto dim_output = phi::make_ddim({batch, seq_len, hidden});
out->set_dims(dim_output);
out->share_lod(*ids[0]);
out->set_dtype((*embs[0]).dtype());
}

void FusionTransposeFlattenConcatInferMeta(
Expand Down Expand Up @@ -1976,6 +1977,7 @@ void FusionTransposeFlattenConcatInferMeta(
out_dims[concat_axis] = -1;
}
out->set_dims(phi::make_ddim(out_dims));
out->set_dtype((*x[0]).dtype());
}

void FusedFCElementwiseLayerNormInferMeta(const MetaTensor& x,
Expand Down Expand Up @@ -2157,11 +2159,14 @@ void FusedFCElementwiseLayerNormInferMeta(const MetaTensor& x,
}

out->set_dims(y_dims);
out->set_dtype(x.dtype());
if (mean) {
mean->set_dtype(x.dtype());
mean->set_dims({dim_0});
}
if (variance) {
variance->set_dims({dim_0});
variance->set_dtype(x.dtype());
}
out->share_lod(x);
}
Expand Down Expand Up @@ -2265,6 +2270,7 @@ void FusionGRUInferMeta(const MetaTensor& x,
"receiced the width of H0 is:%d, frame size is:%d",
h0_dims[1],
frame_size));
reordered_h0->set_dtype(x.dtype());
}
if (bias) {
auto b_dims = bias.dims();
Expand Down Expand Up @@ -2293,16 +2299,20 @@ void FusionGRUInferMeta(const MetaTensor& x,
DDim out_dims({x_mat_dims[0], frame_size});
hidden->set_dims(out_dims);
hidden->share_lod(x);
hidden->set_dtype(x.dtype());
int xx_width = 0;
if (use_seq) {
xx_width = static_cast<int>(wx_dims[1]);
} else {
xx_width = static_cast<int>(x_mat_dims[1] > wx_dims[1] ? wx_dims[1]
: x_mat_dims[1]);
batched_input->set_dims({x_mat_dims[0], wx_dims[1]});
batched_input->set_dtype(x.dtype());
batched_out->set_dims(out_dims);
batched_out->set_dtype(x.dtype());
}
xx->set_dims({x_mat_dims[0], xx_width});
xx->set_dtype(x.dtype());
xx->share_lod(x);
}

Expand Down Expand Up @@ -2365,6 +2375,8 @@ void FusionSeqConvEltAddReluInferMeta(const MetaTensor& x,
out->set_dims({x_dims[0], w_dims[1]});
col_mat->set_dims({x_dims[0], w_dims[0]});
out->share_lod(x);
col_mat->set_dtype(x.dtype());
out->set_dtype(x.dtype());
}

void FusionSeqExpandConcatFCInferMeta(const std::vector<const MetaTensor*>& x,
Expand Down Expand Up @@ -2440,10 +2452,11 @@ void FusionSeqExpandConcatFCInferMeta(const std::vector<const MetaTensor*>& x,
b_dims[1]));
}
}
fc_out->set_dtype((*x[0]).dtype());
out->set_dims({ins_dims[0][0], D});
out->set_dtype((*x[0]).dtype());
// fcout should be reshape when run since can not get lod in infershape
// explicit share the ref lod
// ctx->ShareLoD("X", "Out", 0);
out->share_lod(*x[0]);
}
} // namespace phi

0 comments on commit 187fb27

Please sign in to comment.