Skip to content

Commit

Permalink
remove some describe
Browse files — Browse the repository at this point in the history
  • Loading branch information
zeroRains committed Oct 25, 2023
1 parent 1a622bf commit 25a3817
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 54 deletions.
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/fused_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@
data_type : x

- op : fusion_gru
args : (Tensor x, Tensor h0, Tensor weight_x, Tensor weight_h, Tensor bias, str activation = "tanh", str gate_activation = "sigmoid", bool is_reverse = false, bool use_seq = true, bool origin_mode = false, bool use_onednn = false, str onednn_data_type = "float32", float scale_data = 1.0f, float shift_data = 0.0f, float[] scale_weights = {1.0f}, bool force_fp32_output = false)
args : (Tensor x, Tensor h0, Tensor weight_x, Tensor weight_h, Tensor bias, str activation = "tanh", str gate_activation = "sigmoid", bool is_reverse = false, bool use_seq = true, bool origin_mode = false, bool use_mkldnn = false, str mkldnn_data_type = "float32", float scale_data = 1.0f, float shift_data = 0.0f, float[] scale_weights = {1.0f}, bool force_fp32_output = false)
output : Tensor(reordered_h0), Tensor(xx), Tensor(batched_input), Tensor(batched_out), Tensor(hidden)
infer_meta :
func : FusionGRUInferMeta
Expand Down
2 changes: 0 additions & 2 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1339,8 +1339,6 @@
batched_out : BatchedOut
hidden : Hidden
attrs :
use_onednn : use_mkldnn
onednn_data_type : mkldnn_data_type
scale_data : Scale_data
shift_data : Shift_data
scale_weights : Scale_weights
Expand Down
16 changes: 8 additions & 8 deletions paddle/phi/infermeta/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2176,8 +2176,8 @@ void FusionGRUInferMeta(const MetaTensor& x,
const bool is_reverse,
const bool use_seq,
const bool origin_mode,
const bool use_onednn,
const std::string& onednn_data_type,
const bool use_mkldnn,
const std::string& mkldnn_data_type,
const float scale_data,
const float shift_data,
const std::vector<float>& scale_weights,
Expand All @@ -2187,15 +2187,15 @@ void FusionGRUInferMeta(const MetaTensor& x,
MetaTensor* batched_input,
MetaTensor* batched_out,
MetaTensor* hidden) {
std::string onednn_data_type_list[] = {"float32", "int8", "bfloat16"};
std::string mkldnn_data_type_list[] = {"float32", "int8", "bfloat16"};
PADDLE_ENFORCE_EQ(
std::find(std::begin(onednn_data_type_list),
std::end(onednn_data_type_list),
onednn_data_type) != std::end(onednn_data_type_list),
std::find(std::begin(mkldnn_data_type_list),
std::end(mkldnn_data_type_list),
mkldnn_data_type) != std::end(mkldnn_data_type_list),
true,
phi::errors::InvalidArgument("The onednn_data_type shoule be [float32, "
phi::errors::InvalidArgument("The mkldnn_data_type shoule be [float32, "
"int8, bfloat16], but found %s.",
onednn_data_type.c_str()));
mkldnn_data_type.c_str()));

DDim x_dims = x.dims();
auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1)
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/infermeta/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -525,8 +525,8 @@ void FusionGRUInferMeta(const MetaTensor& x,
const bool is_reverse,
const bool use_seq,
const bool origin_mode,
const bool use_onednn,
const std::string& onednn_data_type,
const bool use_mkldnn,
const std::string& mkldnn_data_type,
const float scale_data,
const float shift_data,
const std::vector<float>& scale_weights,
Expand Down
24 changes: 12 additions & 12 deletions paddle/phi/kernels/fusion/cpu/fusion_gru_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ void SeqCompute(const Context& dev_ctx,
const bool is_reverse,
const bool use_seq,
const bool origin_mode,
const bool use_onednn,
const std::string& onednn_data_type,
const bool use_mkldnn,
const std::string& mkldnn_data_type,
const float scale_data,
const float shift_data,
const std::vector<float>& scale_weights,
Expand Down Expand Up @@ -184,8 +184,8 @@ void BatchCompute(const Context& dev_ctx,
const bool is_reverse,
const bool use_seq,
const bool origin_mode,
const bool use_onednn,
const std::string& onednn_data_type,
const bool use_mkldnn,
const std::string& mkldnn_data_type,
const float scale_data,
const float shift_data,
const std::vector<float>& scale_weights,
Expand All @@ -209,8 +209,8 @@ void BatchCompute(const Context& dev_ctx,
is_reverse,
use_seq,
origin_mode,
use_onednn,
onednn_data_type,
use_mkldnn,
mkldnn_data_type,
scale_data,
shift_data,
scale_weights,
Expand Down Expand Up @@ -372,8 +372,8 @@ void FusionGRUKernel(const Context& dev_ctx,
const bool is_reverse,
const bool use_seq,
const bool origin_mode,
const bool use_onednn,
const std::string& onednn_data_type,
const bool use_mkldnn,
const std::string& mkldnn_data_type,
const float scale_data,
const float shift_data,
const std::vector<float>& scale_weights,
Expand All @@ -395,8 +395,8 @@ void FusionGRUKernel(const Context& dev_ctx,
is_reverse,
use_seq,
origin_mode,
use_onednn,
onednn_data_type,
use_mkldnn,
mkldnn_data_type,
scale_data,
shift_data,
scale_weights,
Expand All @@ -418,8 +418,8 @@ void FusionGRUKernel(const Context& dev_ctx,
is_reverse,
use_seq,
origin_mode,
use_onednn,
onednn_data_type,
use_mkldnn,
mkldnn_data_type,
scale_data,
shift_data,
scale_weights,
Expand Down
58 changes: 29 additions & 29 deletions paddle/phi/kernels/fusion/onednn/fusion_gru_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,8 @@ void RunKernel(const phi::OneDNNContext& dev_ctx,
const bool is_reverse,
const bool use_seq,
const bool origin_mode,
const bool use_onednn,
const std::string& onednn_data_type,
const bool use_mkldnn,
const std::string& mkldnn_data_type,
const float scale_data,
const float shift_data,
const std::vector<float>& scale_weights,
Expand Down Expand Up @@ -553,28 +553,28 @@ void RunKernel(const phi::OneDNNContext& dev_ctx,
}

template <typename T, typename Context>
void FusionGRUOneDNNKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& h0,
const DenseTensor& weight_x,
const DenseTensor& weight_h,
const paddle::optional<DenseTensor>& bias,
const std::string& activation,
const std::string& gate_activation,
const bool is_reverse,
const bool use_seq,
const bool origin_mode,
const bool use_onednn,
const std::string& onednn_data_type,
const float scale_data,
const float shift_data,
const std::vector<float>& scale_weights,
const bool force_fp32_output,
DenseTensor* reordered_h0,
DenseTensor* xx,
DenseTensor* batched_input,
DenseTensor* batched_out,
DenseTensor* hidden) {
void FusionGRUKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& h0,
const DenseTensor& weight_x,
const DenseTensor& weight_h,
const paddle::optional<DenseTensor>& bias,
const std::string& activation,
const std::string& gate_activation,
const bool is_reverse,
const bool use_seq,
const bool origin_mode,
const bool use_mkldnn,
const std::string& mkldnn_data_type,
const float scale_data,
const float shift_data,
const std::vector<float>& scale_weights,
const bool force_fp32_output,
DenseTensor* reordered_h0,
DenseTensor* xx,
DenseTensor* batched_input,
DenseTensor* batched_out,
DenseTensor* hidden) {
const bool is_bf16 = std::is_same<T, phi::dtype::bfloat16>::value;
// BF16 does not support force output
if (!is_bf16 && force_fp32_output) { // NOLINT
Expand All @@ -589,8 +589,8 @@ void FusionGRUOneDNNKernel(const Context& dev_ctx,
is_reverse,
use_seq,
origin_mode,
use_onednn,
onednn_data_type,
use_mkldnn,
mkldnn_data_type,
scale_data,
shift_data,
scale_weights,
Expand All @@ -612,8 +612,8 @@ void FusionGRUOneDNNKernel(const Context& dev_ctx,
is_reverse,
use_seq,
origin_mode,
use_onednn,
onednn_data_type,
use_mkldnn,
mkldnn_data_type,
scale_data,
shift_data,
scale_weights,
Expand All @@ -632,7 +632,7 @@ void FusionGRUOneDNNKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(fusion_gru,
OneDNN,
ONEDNN,
phi::fusion::FusionGRUOneDNNKernel,
phi::fusion::FusionGRUKernel,
float,
phi::dtype::bfloat16,
uint8_t) {}

0 comments on commit 25a3817

Please sign in to comment.