Add symbolic reciprocal
hxzd5568 committed Apr 23, 2024
1 parent ff8ef18 commit 90524d9
Showing 5 changed files with 94 additions and 0 deletions.
49 changes: 49 additions & 0 deletions paddle/cinn/hlir/op/contrib/reciprocal.cc
@@ -125,6 +125,53 @@ std::shared_ptr<OpStrategy> StrategyForReciprocal(
return strategy;
}

std::shared_ptr<OpStrategy> StrategyForReciprocalSymbolic(
const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<ir::Dim>> &output_shapes,
const Target &target) {
std::string op_name("reciprocal");

framework::CINNCompute reciprocal_compute(
[=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of " << op_name
<< " compute is empty! Please check.\n";
CINNValuePack pack_args = args[0];
CHECK(!pack_args.empty())
<< "at least one input tensor for " << op_name << " compute\n";

CHECK_EQ(pack_args.size(), 2);
CHECK(pack_args[1].is_string());
std::string tensor_name = pack_args[1].operator std::string();

Expr A = pack_args[0];
CHECK(A.as_tensor());
CHECK(!output_shapes.empty());
auto tensor_A = A.as_tensor_ref();
auto stages = CreateStages({tensor_A});
VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
<< ", output_shapes: " << utils::Join(output_shapes[0], ", ");

ir::Tensor out = Reciprocal(tensor_A, tensor_name);
std::vector<CINNValue> res;
stages->InsertLazily(out);
res.push_back(CINNValue(out));
CHECK(!out_type.empty())
<< "Output type of Reciprocal is empty! Please check.\n";
res.push_back(CINNValue(stages));
*ret = CINNValuePack{res};
});

auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(
reciprocal_compute, lang::PackedFunc(), "strategy.reciprocal.x86", 1);
return strategy;
}

std::vector<framework::shape_t> InferShapeForReciprocal(
const std::vector<framework::shape_t> &inputs_shape,
const framework::AttrMapType &attrs) {
@@ -153,6 +200,8 @@ CINN_REGISTER_HELPER(reciprocal_ops) {
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>(
"CINNStrategy", cinn::hlir::op::StrategyForReciprocal)
.set_attr<cinn::hlir::framework::StrategyFunctionSymbolic>(
"CINNStrategySymbolic", cinn::hlir::op::StrategyForReciprocalSymbolic)
.set_attr("infershape",
MakeOpFunction(cinn::hlir::op::InferShapeForReciprocal))
.set_attr("inferdtype",
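For reference, the strategy added above lowers an elementwise reciprocal, out[i] = 1 / x[i], taking the input tensor and the output tensor name from pack_args. A minimal NumPy sketch of that semantics (illustrative only; the real kernel is generated through CINN's IR, and reciprocal_reference below is not part of this change):

import numpy as np

def reciprocal_reference(x: np.ndarray) -> np.ndarray:
    # Elementwise reciprocal: out[i] = 1 / x[i].
    return 1.0 / x

x = np.random.randn(4, 8).astype("float32") + 2.0  # shift away from zero
np.testing.assert_allclose(reciprocal_reference(x), np.reciprocal(x), rtol=1e-6)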
@@ -100,6 +100,8 @@ OP_SAME_OPERANDS_AND_RESULT(Print)
OP_SAME_OPERANDS_AND_RESULT(PutAlongAxis)
OP_SAME_OPERANDS_AND_RESULT(PutAlongAxis_)
OP_SAME_OPERANDS_AND_RESULT(Real)
OP_SAME_OPERANDS_AND_RESULT(Reciprocal)
OP_SAME_OPERANDS_AND_RESULT(Reciprocal_)
OP_SAME_OPERANDS_AND_RESULT(Relu)
OP_SAME_OPERANDS_AND_RESULT(Relu6)
OP_SAME_OPERANDS_AND_RESULT(Relu_)
@@ -91,6 +91,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Print)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(PutAlongAxis)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(PutAlongAxis_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Real)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reciprocal)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reciprocal_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Relu)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Relu6)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Relu_)
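The two hunks above register Reciprocal and its inplace variant Reciprocal_ as same-operands-and-result ops for symbolic shape inference: the output simply reuses the input's symbolic shape. A hedged Python sketch of that inference rule (purely illustrative; the actual behaviour comes from the C++ macros above):

from typing import List, Union

SymDim = Union[int, str]  # e.g. 32, or "S0" for a dynamic dimension

def infer_same_operands_and_result(input_shape: List[SymDim]) -> List[SymDim]:
    # Elementwise ops such as reciprocal keep the shape, so the output's
    # symbolic shape is a copy of the input's.
    return list(input_shape)

assert infer_same_operands_and_result(["S0", 32]) == ["S0", 32]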
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -2308,6 +2308,7 @@
func : reciprocal
inplace : (x -> out)
backward : reciprocal_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : reduce_as
args : (Tensor x, Tensor target)
40 changes: 40 additions & 0 deletions test/ir/pir/cinn/symbolic/test_cinn_elementwise_symbolic.py
@@ -28,6 +28,10 @@ def tril(x):
return paddle.tril(x)


def reciprocal(x):
return paddle.reciprocal(x)


def isinf(x):
return paddle.isinf(x)

@@ -411,5 +415,41 @@ def test_eval_symbolic(self):
np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-8)


class TestCinnSubGraphReciprocal(unittest.TestCase):
"""
Test Pir API + @to_static + CINN.
"""

def setUp(self):
paddle.seed(2022)
self.prepare_data()

def prepare_data(self):
self.x_shape = [32, 32]
self.x = paddle.randn(self.x_shape, dtype="float32")
self.x.stop_gradient = False

def check_jit_kernel_info(self, static_fn):
utils.check_jit_kernel_number(static_fn, 1)

def eval_symbolic(self, use_cinn):
paddle.seed(2022)
net = CINNSubGraphNet(reciprocal)
input_spec = [
InputSpec(shape=[None, 32], dtype='float32'),
]
net = utils.apply_to_static(net, use_cinn, input_spec)
net.eval()
out = net(self.x)
if use_cinn:
self.check_jit_kernel_info(net.forward)
return out

def test_eval_symbolic(self):
cinn_out = self.eval_symbolic(use_cinn=True)
dy_out = self.eval_symbolic(use_cinn=False)
np.testing.assert_allclose(cinn_out.numpy(), dy_out.numpy(), atol=1e-8)


if __name__ == '__main__':
unittest.main()
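In the test above, InputSpec(shape=[None, 32], dtype='float32') is what makes the leading dimension symbolic, so the CINN path exercises StrategyForReciprocalSymbolic rather than the static-shape strategy. A minimal sketch of the same pattern outside the test harness (assumes a Paddle build where paddle.jit.to_static is available; whether CINN actually compiles the graph depends on build flags, which the test handles via utils.apply_to_static):

import numpy as np
import paddle
from paddle.static import InputSpec

def reciprocal(x):
    return paddle.reciprocal(x)

# The leading None keeps the batch dimension symbolic at trace time.
static_fn = paddle.jit.to_static(
    reciprocal, input_spec=[InputSpec(shape=[None, 32], dtype="float32")]
)

for batch in (4, 16):  # one traced program serves both batch sizes
    x = paddle.randn([batch, 32], dtype="float32") + 2.0
    np.testing.assert_allclose(
        static_fn(x).numpy(), 1.0 / x.numpy(), rtol=1e-6, atol=1e-6
    )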
