diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
index ae2d09a827c7f6..f56b1bb321d2c6 100644
--- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
+++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -1230,6 +1230,32 @@ class LessThanOpPattern
     return true;
   }
 };
+
+template <typename OpType>
+class LogicalCommonOpPattern : public pir::OpRewritePattern<OpType> {
+ public:
+  using pir::OpRewritePattern<OpType>::OpRewritePattern;
+  bool MatchAndRewrite(OpType op,
+                       pir::PatternRewriter &rewriter) const override {
+    if (op->HasAttribute(kCanRunTrtAttr) &&
+        op->template attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
+      return false;
+    }
+    pir::Value x = op.operand_source(0);
+    pir::Value y = op.operand_source(1);
+    auto x_dtype = pir::GetDataTypeFromValue(x);
+    auto y_dtype = pir::GetDataTypeFromValue(y);
+    if (!(x_dtype.isa<pir::BoolType>() && y_dtype.isa<pir::BoolType>())) {
+      VLOG(3) << "pd_op.logical_xor op only supports bool datatype";
+      return false;
+    }
+    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
+    return true;
+  }
+};
+using LogicalXorOpPattern =
+    LogicalCommonOpPattern<paddle::dialect::LogicalXorOp>;
+
 class MulticlassNms3OpPattern
     : public pir::OpRewritePattern<paddle::dialect::MulticlassNms3Op> {
  public:
@@ -2259,6 +2285,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
+    ps.Add(std::make_unique<LogicalXorOpPattern>(context));
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
diff --git a/python/paddle/tensorrt/impls/logic.py b/python/paddle/tensorrt/impls/logic.py
index 4d38a06e980218..bc08a86cf48eb4 100644
--- a/python/paddle/tensorrt/impls/logic.py
+++ b/python/paddle/tensorrt/impls/logic.py
@@ -23,12 +23,14 @@
     "pd_op.greater_than": trt.ElementWiseOperation.GREATER,
     "pd_op.less_than": trt.ElementWiseOperation.LESS,
     "pd_op.equal": trt.ElementWiseOperation.EQUAL,
+    "pd_op.logical_xor": trt.ElementWiseOperation.XOR,
 }
 
 
 @converter_registry.register("pd_op.greater_than", trt_version="8.x")
 @converter_registry.register("pd_op.less_than", trt_version="8.x")
 @converter_registry.register("pd_op.equal", trt_version="8.x")
+@converter_registry.register("pd_op.logical_xor", trt_version="8.x")
 def logic_converter(network, paddle_op, inputs):
     layer_output = add_elementwise_layer(
         network, paddle_op, inputs, logic_type_map[paddle_op.name()]
diff --git a/test/tensorrt/test_converter_logic.py b/test/tensorrt/test_converter_logic.py
index a6fd348514966d..6e63aaf7fea159 100644
--- a/test/tensorrt/test_converter_logic.py
+++ b/test/tensorrt/test_converter_logic.py
@@ -140,5 +140,44 @@ def test_trt_result(self):
         self.check_trt_result()
 
 
+class TestLogicalXorTRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.logical_xor
+
+    def test_trt_result(self):
+        self.api_args = {
+            "x": np.random.choice([True, False], size=(3,)).astype("bool"),
+            "y": np.random.choice([True, False], size=(3,)).astype("bool"),
+        }
+        self.program_config = {"feed_list": ["x", "y"]}
+        self.min_shape = {"x": [1], "y": [1]}
+        self.max_shape = {"x": [5], "y": [5]}
+        self.check_trt_result()
+
+    def test_trt_diff_shape_result(self):
+        self.api_args = {
+            "x": np.random.choice([True, False], size=(2, 3)).astype("bool"),
+            "y": np.random.choice([True, False], size=(3,)).astype("bool"),
+        }
+        self.program_config = {"feed_list": ["x", "y"]}
+        self.min_shape = {"x": [1, 3], "y": [3]}
+        self.max_shape = {"x": [4, 3], "y": [3]}
+        self.check_trt_result()
+
+
+class TestLogicalXorMarker(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.logical_xor
+        self.api_args = {
+            "x": np.random.randn(3).astype("int64"),
+            "y": np.random.randn(3).astype("int64"),
+        }
+        self.program_config = {"feed_list": ["x", "y"]}
+        self.target_marker_op = "pd_op.logical_xor"
+
+    def test_trt_result(self):
+        self.check_marker(expected_result=False)
+
+
 if __name__ == '__main__':
     unittest.main()
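
A minimal usage sketch (not part of the diff) illustrating the dtype contract the new marker pattern enforces: TensorRT's ElementWiseOperation.XOR accepts only BOOL tensors, so the pattern tags `pd_op.logical_xor` with `kCanRunTrtAttr` only when both operands are bool, and non-bool inputs stay on the native Paddle path. The tensor values below are made up for illustration.

```python
import numpy as np
import paddle

# bool inputs: the dtype the new TRT converter handles
# (mirrors TestLogicalXorTRTPattern, including the broadcast case)
x = paddle.to_tensor(np.random.choice([True, False], size=(2, 3)))
y = paddle.to_tensor(np.random.choice([True, False], size=(3,)))
print(paddle.logical_xor(x, y))  # bool result, y broadcast over dim 0

# int64 inputs: paddle.logical_xor itself still works eagerly, but the
# marker pass leaves kCanRunTrtAttr unset, so the op is not offloaded to
# TensorRT (TestLogicalXorMarker asserts this with expected_result=False)
xi = paddle.to_tensor(np.array([0, 1, 2], dtype="int64"))
yi = paddle.to_tensor(np.array([1, 1, 0], dtype="int64"))
print(paddle.logical_xor(xi, yi))
```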