diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
index 71198107456030..130f4e7dad4ec8 100644
--- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
+++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -1453,13 +1453,14 @@ class StackOpPattern : public pir::OpRewritePattern<paddle::dialect::StackOp> {
   }
 };
 
-class TanhOpPattern : public pir::OpRewritePattern<paddle::dialect::TanhOp> {
+template <typename OpType>
+class ActOpPattern : public pir::OpRewritePattern<OpType> {
  public:
-  using pir::OpRewritePattern<paddle::dialect::TanhOp>::OpRewritePattern;
-  bool MatchAndRewrite(paddle::dialect::TanhOp op,
+  using pir::OpRewritePattern<OpType>::OpRewritePattern;
+  bool MatchAndRewrite(OpType op,
                        pir::PatternRewriter &rewriter) const override {
     if (op->HasAttribute(kCanRunTrtAttr) &&
-        op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
+        op->template attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
       return false;
     }
 #if IS_TRT_VERSION_LT(8600)
@@ -1477,6 +1478,8 @@ class TanhOpPattern : public pir::OpRewritePattern<paddle::dialect::TanhOp> {
     return true;
   }
 };
+using TanhOpPattern = ActOpPattern<paddle::dialect::TanhOp>;
+using SoftplusOpPattern = ActOpPattern<paddle::dialect::SoftplusOp>;
 
 class WherePattern : public pir::OpRewritePattern<paddle::dialect::WhereOp> {
  public:
@@ -1783,6 +1786,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
     ps.Add(std::make_unique<LogicalNotOpPattern>(context));
     ps.Add(std::make_unique<StackOpPattern>(context));
     ps.Add(std::make_unique<TanhOpPattern>(context));
+    ps.Add(std::make_unique<SoftplusOpPattern>(context));
     ps.Add(std::make_unique<WherePattern>(context));
     ps.Add(std::make_unique<FullLikeOpPattern>(context));
     return ps;
diff --git a/python/paddle/tensorrt/impls/activation.py b/python/paddle/tensorrt/impls/activation.py
index f01c12cf30baef..a81a23e369baaa 100644
--- a/python/paddle/tensorrt/impls/activation.py
+++ b/python/paddle/tensorrt/impls/activation.py
@@ -104,6 +104,28 @@ def hardswish_converter(network, paddle_op, inputs):
     return hardswish_layer.get_output(0)
 
 
+@converter_registry.register("pd_op.softplus", trt_version="8.x")
+def softplus_converter(network, paddle_op, inputs):
+    x = inputs[0]
+    beta = paddle_op.attrs()["beta"]
+    threshold = paddle_op.attrs()["threshold"]
+    # Cap the input at threshold / beta so exp(beta * x) cannot overflow
+    # inside the SOFTPLUS activation; paddle's softplus switches to the
+    # identity beyond this point.
+    layer_clip = network.add_activation(x, trt.ActivationType.CLIP)
+    layer_clip.alpha = -3.40282e038  # float32 lowest: effectively no lower bound
+    layer_clip.beta = threshold / beta
+
+    # TensorRT's SOFTPLUS computes alpha * log(exp(beta * x) + 1), so
+    # alpha = 1 / beta recovers paddle's log(1 + exp(beta * x)) / beta.
+    softplus_layer = network.add_activation(
+        layer_clip.get_output(0), trt.ActivationType.SOFTPLUS
+    )
+    softplus_layer.alpha = 1.0 / beta
+    softplus_layer.beta = beta
+    return softplus_layer.get_output(0)
+
+
 @converter_registry.register("pd_op.swish", trt_version="8.x")
 @converter_registry.register("pd_op.silu", trt_version="8.x")
 def swish_silu_converter(network, paddle_op, inputs):
diff --git a/test/tensorrt/test_converter_activation.py b/test/tensorrt/test_converter_activation.py
index 170359e210dc45..b36110c4abcdb0 100644
--- a/test/tensorrt/test_converter_activation.py
+++ b/test/tensorrt/test_converter_activation.py
@@ -48,10 +48,10 @@ def test_trt_result(self):
         self.check_trt_result()
 
 
-class TestRELUTRTPattern(TensorRTBaseTest):
+class TestReluTRTPattern(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.nn.functional.relu
-        self.api_args = {"x": np.random.randn(3).astype(np.float32)}
+        self.api_args = {"x": np.random.randn(3).astype("float32")}
         self.program_config = {"feed_list": ["x"]}
         self.min_shape = {"x": [1]}
         self.max_shape = {"x": [5]}
@@ -60,10 +60,10 @@ def test_trt_result(self):
         self.check_trt_result()
 
 
-class TestTANHTRTPattern(TensorRTBaseTest):
+class TestTanhTRTPattern(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.tanh
-        self.api_args = {"x": np.random.randn(3).astype(np.float32)}
+        self.api_args = {"x": np.random.randn(3).astype("float32")}
         self.program_config = {"feed_list": ["x"]}
         self.min_shape = {"x": [1]}
         self.max_shape = {"x": [5]}
@@ -76,11 +76,25 @@ class TestSigmoidTRTPattern(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.nn.functional.sigmoid
         self.api_args = {
-            "x": np.random.randn(2, 3).astype(np.float32),
+            "x": np.random.randn(2, 3).astype("float32"),
         }
         self.program_config = {"feed_list": ["x"]}
-        self.min_shape = {"x": [1, 3], "y": [1, 3]}
-        self.max_shape = {"x": [5, 3], "y": [5, 3]}
+        self.min_shape = {"x": [1, 3]}
+        self.max_shape = {"x": [5, 3]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+
+class TestSoftplusTRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.nn.Softplus()
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype("float32"),
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [1, 3]}
+        self.max_shape = {"x": [5, 3]}
 
     def test_trt_result(self):
         self.check_trt_result()
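
Reviewer note on the converter's math: TensorRT's SOFTPLUS activation computes alpha * log(exp(beta * x) + 1), and the preceding CLIP layer caps the input at threshold / beta, so inside the threshold the pair reduces exactly to paddle's softplus, log(1 + exp(beta * x)) / beta. The NumPy sketch below is not part of the diff (values and names are illustrative); it checks that equivalence on the in-threshold range:

import numpy as np

beta, threshold = 1.0, 20.0  # paddle's softplus defaults
x = np.linspace(-10.0, 10.0, 101).astype("float32")  # beta * x <= threshold

clipped = np.minimum(x, threshold / beta)                  # CLIP layer
trt_out = (1.0 / beta) * np.log1p(np.exp(beta * clipped))  # SOFTPLUS layer

# paddle's definition: identity once beta * x exceeds threshold
paddle_ref = np.where(beta * x > threshold, x, np.log1p(np.exp(beta * x)) / beta)
assert np.allclose(trt_out, paddle_ref, atol=1e-5)

Past threshold / beta the clipped graph saturates near threshold / beta while paddle returns x itself, so the two agree there only to within about exp(-threshold); the test inputs above (and in TestSoftplusTRTPattern) stay well inside the threshold.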