[PIR] pir onednn support mul (#61662)
* pir onednn support mul
wanghuancoder authored Feb 26, 2024
1 parent 0b5ae5b commit cc63252
Showing 10 changed files with 86 additions and 13 deletions.
@@ -116,7 +116,8 @@ void TensorNameMap(pir::Operation* op,

auto& name2id = op_yaml_info.InputName2Id();

-std::string fluid_op_name = op_yaml_info.GetOriginOpName();
+std::string fluid_op_name =
+phi::TransToFluidOpName(op_yaml_info.OpRuntimeInfo().kernel_func);

auto& op_normalizer = paddle::translator::OpNameNormalizer::instance();

@@ -327,7 +328,8 @@ OneDNNPhiKernelInstruction::OneDNNPhiKernelInstruction(
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
auto& op_normalizer = paddle::translator::OpNameNormalizer::instance();
-std::string fluid_op_name = yaml_info_parser.GetOriginOpName();
+std::string fluid_op_name =
+phi::TransToFluidOpName(yaml_info_parser.OpRuntimeInfo().kernel_func);

for (auto& attr : extra_args_attr) {
auto attr_name = attr.dyn_cast<pir::StrAttribute>().AsString();
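The recurring change in these instruction/context hunks replaces op_yaml_info.GetOriginOpName() with phi::TransToFluidOpName(...kernel_func), so the fluid-era op name handed to the OpNameNormalizer and the OneDNN extra-attribute lookups is derived from the kernel function rather than from the PIR origin name. phi::TransToFluidOpName itself is not part of this diff; the sketch below is only an illustrative stand-in with hypothetical table entries (for example, the matmul_with_flatten kernel mapping back to the legacy mul op) to show the intended effect.

#include <string>
#include <unordered_map>

// Illustrative stand-in for phi::TransToFluidOpName (not the real
// implementation): map a PHI kernel name to the legacy fluid op name that
// the OneDNN extra-attr and OpNameNormalizer tables are keyed on.
std::string TransToFluidOpNameSketch(const std::string& kernel_func) {
  // Hypothetical entries for illustration only.
  static const std::unordered_map<std::string, std::string> kPhiToFluid = {
      {"matmul_with_flatten", "mul"},
      {"matmul_with_flatten_grad", "mul_grad"},
  };
  auto it = kPhiToFluid.find(kernel_func);
  return it != kPhiToFluid.end() ? it->second : kernel_func;
}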
@@ -219,7 +219,8 @@ OneDNNLegacyKernelInstruction::OneDNNLegacyKernelInstruction(
.AsVector();

auto& op_normalizer = paddle::translator::OpNameNormalizer::instance();
-std::string fluid_op_name = yaml_info_parser.GetOriginOpName();
+std::string fluid_op_name =
+phi::TransToFluidOpName(yaml_info_parser.OpRuntimeInfo().kernel_func);
for (auto& attr : data_format_tensors_attr) {
auto input_name = attr.dyn_cast<pir::StrAttribute>().AsString();
data_format_tensors_.insert(
@@ -241,7 +242,8 @@ OneDNNLegacyKernelInstruction::OneDNNLegacyKernelInstruction(
.AsVector();

auto& op_normalizer = paddle::translator::OpNameNormalizer::instance();
-std::string fluid_op_name = yaml_info_parser.GetOriginOpName();
+std::string fluid_op_name =
+phi::TransToFluidOpName(yaml_info_parser.OpRuntimeInfo().kernel_func);

for (auto& input : skip_transform_inputs) {
auto input_name = input.dyn_cast<pir::StrAttribute>().AsString();
@@ -805,7 +805,8 @@ void BuildRuntimeContext(pir::Operation* op,

auto& name2id = op_yaml_info.InputName2Id();

-std::string fluid_op_name = op_yaml_info.GetOriginOpName();
+std::string fluid_op_name =
+phi::TransToFluidOpName(op_yaml_info.OpRuntimeInfo().kernel_func);

auto& op_normalizer = paddle::translator::OpNameNormalizer::instance();

@@ -890,7 +891,8 @@ std::shared_ptr<OperatorBase> BuildOperatorBase(

auto& name2id = op_yaml_info.InputName2Id();

-std::string fluid_op_name = op_yaml_info.GetOriginOpName();
+std::string fluid_op_name =
+phi::TransToFluidOpName(op_yaml_info.OpRuntimeInfo().kernel_func);

auto& op_normalizer = paddle::translator::OpNameNormalizer::instance();

26 changes: 26 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -1451,6 +1451,19 @@ ValueInfo GetTensorInfoByVarName(const OpDesc& op_desc,
}

struct MulOpTranscriber : public OpTranscriber {
+pir::Operation* operator()(pir::IrContext* ctx,
+TranslationContext* param_map,
+const OpDesc& op_desc,
+pir::Block* block) override {
+#ifdef PADDLE_WITH_DNNL
+if (op_desc.GetAttrIfExists<bool>("use_mkldnn")) {
+return static_cast<OpTranscriber>(*this).operator()(
+ctx, param_map, op_desc, block);
+}
+#endif
+return OpTranscriber::operator()(ctx, param_map, op_desc, block);
+}
+
pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
const OpDesc& op_desc) override {
const std::string& target_op_name = paddle::dialect::MatmulOp::name();
@@ -1605,6 +1618,19 @@ struct MulOpTranscriber : public OpTranscriber {
};

struct MulGradOpTranscriber : public OpTranscriber {
+pir::Operation* operator()(pir::IrContext* ctx,
+TranslationContext* param_map,
+const OpDesc& op_desc,
+pir::Block* block) override {
+#ifdef PADDLE_WITH_DNNL
+if (op_desc.GetAttrIfExists<bool>("use_mkldnn")) {
+return static_cast<OpTranscriber>(*this).operator()(
+ctx, param_map, op_desc, block);
+}
+#endif
+return OpTranscriber::operator()(ctx, param_map, op_desc, block);
+}
+
pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
const OpDesc& op_desc) override {
const std::string& target_op_name = paddle::dialect::MatmulGradOp::name();
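The operator() overrides added to MulOpTranscriber and MulGradOpTranscriber rely on object slicing: static_cast<OpTranscriber>(*this) copies the transcriber into a plain base-class object, so the derived overrides that redirect mul to matmul drop out of virtual dispatch and the generic translation path (which can resolve the OneDNN variant, matmul_with_flatten) runs whenever use_mkldnn is set. The standalone sketch below, with hypothetical Base/Derived names, only demonstrates that slicing-dispatch idiom; it is not Paddle code.

#include <iostream>

struct Base {
  virtual ~Base() = default;
  virtual const char* LookUp() { return "generic lookup"; }
  void Translate() { std::cout << LookUp() << "\n"; }
};

struct Derived : Base {
  const char* LookUp() override { return "special mul->matmul lookup"; }
  void Translate(bool use_onednn) {
    if (use_onednn) {
      // Slicing to Base drops the Derived override for this call chain.
      static_cast<Base>(*this).Translate();  // prints "generic lookup"
      return;
    }
    Base::Translate();  // prints "special mul->matmul lookup"
  }
};

int main() {
  Derived d;
  d.Translate(true);   // OneDNN path: fall back to the generic behaviour
  d.Translate(false);  // default path: keep the specialised behaviour
}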
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -107,6 +107,7 @@
'onednn_to_paddle_layout',
'lrn',
'multi_gru',
+'matmul_with_flatten',
]

NO_NEED_GEN_STATIC_ONLY_APIS = [
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -848,6 +848,16 @@
backward : matmul_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

+- op : matmul_with_flatten
+args : (Tensor x, Tensor y, int x_num_col_dims = 1, int y_num_col_dims = 1)
+output : Tensor
+infer_meta :
+func : MatmulWithFlattenInferMeta
+kernel :
+func : matmul_with_flatten
+data_type : x
+backward : matmul_with_flatten_grad
+
- op : matrix_rank
args : (Tensor x, float tol, bool use_default_tol=true, bool hermitian=false)
output : Tensor(out)
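For context, matmul_with_flatten is the PIR spelling of the legacy mul computation: x is flattened into a 2-D matrix whose rows come from its first x_num_col_dims dimensions (default 1), y is flattened the same way using y_num_col_dims, and the two matrices are multiplied. The naive reference below (a hypothetical helper, not the Paddle kernel) sketches that semantics.

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Naive reference for the matmul_with_flatten / legacy "mul" semantics.
std::vector<float> MatmulWithFlattenRef(const std::vector<float>& x,
                                        const std::vector<int64_t>& x_dims,
                                        const std::vector<float>& y,
                                        const std::vector<int64_t>& y_dims,
                                        int x_num_col_dims = 1,
                                        int y_num_col_dims = 1) {
  auto prod = [](const std::vector<int64_t>& d, size_t b, size_t e) {
    return std::accumulate(d.begin() + b, d.begin() + e, int64_t{1},
                           std::multiplies<int64_t>());
  };
  const int64_t m = prod(x_dims, 0, x_num_col_dims);              // rows of flattened x
  const int64_t k = prod(x_dims, x_num_col_dims, x_dims.size());  // cols of flattened x
  const int64_t n = prod(y_dims, y_num_col_dims, y_dims.size());  // cols of flattened y
  // The first y_num_col_dims dimensions of y must flatten to k rows.
  std::vector<float> out(static_cast<size_t>(m * n), 0.0f);
  for (int64_t i = 0; i < m; ++i)
    for (int64_t j = 0; j < n; ++j)
      for (int64_t p = 0; p < k; ++p)
        out[i * n + j] += x[i * k + p] * y[p * n + j];
  return out;
}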
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
@@ -412,6 +412,16 @@
func : matmul_grad
backward : matmul_double_grad

+- backward_op : matmul_with_flatten_grad
+forward : matmul_with_flatten (Tensor x, Tensor y, int x_num_col_dims=1, int y_num_col_dims=1) -> Tensor(out)
+args : (Tensor x, Tensor y, Tensor out_grad, int x_num_col_dims=1, int y_num_col_dims=1)
+output : Tensor(x_grad), Tensor(y_grad)
+infer_meta :
+func : GeneralBinaryGradInferMeta
+param : [x, y]
+kernel :
+func : matmul_with_flatten_grad
+
- backward_op : max_grad
forward: max (Tensor x, IntArray axis={}, bool keepdim=false) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, IntArray axis={}, bool keepdim=false, bool reduce_all=false)
6 changes: 4 additions & 2 deletions paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
@@ -156,9 +156,11 @@
extra_args : str mkldnn_data_type="float32"
data_format_tensors : x, y, out_grad

-# - op : matmul_with_flatten
+- op : matmul_with_flatten
+extra_args : float scale_x=1.0, float[] scale_y={1.0}, float scale_out=1.0, bool force_fp32_output=false

-# - op : matmul_with_flatten_grad
+- op : matmul_with_flatten_grad
+extra_args : float scale_x=1.0, float[] scale_y={1.0}, float scale_out=1.0, bool force_fp32_output=false

- op : max
dynamic_fallback : True
2 changes: 1 addition & 1 deletion test/mkldnn/test_mul_int8_mkldnn_op.py
@@ -79,7 +79,7 @@ def init_data(self):
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output_with_place(
-core.CPUPlace(), atol=0, check_dygraph=False
+core.CPUPlace(), atol=0, check_dygraph=False, check_pir_onednn=True
)


26 changes: 22 additions & 4 deletions test/mkldnn/test_mul_mkldnn_op.py
@@ -57,16 +57,28 @@ def init_inputs_dtype(self):
pass

def test_check_output(self):
-self.check_output_with_place(core.CPUPlace())
+self.check_output_with_place(
+core.CPUPlace(), check_pir_onednn=True, check_dygraph=False
+)

def test_check_grad(self):
-self.check_grad_with_place(core.CPUPlace(), ['X', 'Y'], 'Out')
+self.check_grad_with_place(
+core.CPUPlace(),
+['X', 'Y'],
+'Out',
+check_pir_onednn=True,
+check_dygraph=False,
+)

def test_check_grad_ignore_x(self):
-self.check_grad_with_place(core.CPUPlace(), ['Y'], 'Out', set('X'))
+self.check_grad_with_place(
+core.CPUPlace(), ['Y'], 'Out', set('X'), check_pir_onednn=True
+)

def test_check_grad_ignore_y(self):
-self.check_grad_with_place(core.CPUPlace(), ['X'], 'Out', set('Y'))
+self.check_grad_with_place(
+core.CPUPlace(), ['X'], 'Out', set('Y'), check_pir_onednn=True
+)


class TestMulXNumColDims2OneDNNOp(TestMulOneDNNOp):
@@ -135,6 +147,8 @@ def test_check_grad(self):
'Out',
user_defined_grads=[self.dx, self.dy],
user_defined_grad_outputs=[convert_float_to_uint16(self.dout)],
+check_pir_onednn=True,
+check_dygraph=False,
)

def test_check_grad_ignore_x(self):
@@ -146,6 +160,8 @@ def test_check_grad_ignore_x(self):
set('X'),
user_defined_grads=[self.dy],
user_defined_grad_outputs=[convert_float_to_uint16(self.dout)],
+check_pir_onednn=True,
+check_dygraph=False,
)

def test_check_grad_ignore_y(self):
@@ -157,6 +173,8 @@ def test_check_grad_ignore_y(self):
set('Y'),
user_defined_grads=[self.dx],
user_defined_grad_outputs=[convert_float_to_uint16(self.dout)],
+check_pir_onednn=True,
+check_dygraph=False,
)


Expand Down
