diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
index d82a87d4e08111..a10ea3dbc1d4c5 100644
--- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
+++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -63,6 +63,7 @@ DEFINE_GENERAL_PATTERN(Fused_gemm_epilogue,
                        paddle::dialect::FusedGemmEpilogueOp)
 DEFINE_GENERAL_PATTERN(Layer_norm, paddle::dialect::LayerNormOp)
 DEFINE_GENERAL_PATTERN(Add, paddle::dialect::AddOp)
+DEFINE_GENERAL_PATTERN(Isnan, paddle::dialect::IsnanOp)
 DEFINE_GENERAL_PATTERN(Full, paddle::dialect::FullOp)
 DEFINE_GENERAL_PATTERN(Silu, paddle::dialect::SiluOp)
 DEFINE_GENERAL_PATTERN(Conv2d, paddle::dialect::Conv2dOp)
@@ -1409,6 +1410,29 @@ class ArgsortOpPattern
     return true;
   }
 };
+
+class EmbeddingOpPattern
+    : public pir::OpRewritePattern<paddle::dialect::EmbeddingOp> {
+ public:
+  using pir::OpRewritePattern<paddle::dialect::EmbeddingOp>::OpRewritePattern;
+  bool MatchAndRewrite(paddle::dialect::EmbeddingOp op,
+                       pir::PatternRewriter &rewriter) const override {
+    if (op->HasAttribute(kCanRunTrtAttr) &&
+        op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
+      return false;
+    }
+    // if (pir::GetDefiningOpForInput(op, 1)->name() == "builtin.parameter") {
+    //   // trt.Weights don't have the shape info.
+    //   VLOG(3) << "Skip to convert into TRT while found weight is a parameter "
+    //              "in pd_op.embedding.";
+    //   return false;
+    // }
+    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
+    return true;
+  }
+};
+
 class BilinearInterpV2Pattern
     : public pir::OpRewritePattern<paddle::dialect::BilinearInterpOp> {
  public:
@@ -2192,6 +2216,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
     ADD_PATTERN(AssignOut)
     ADD_PATTERN(Assign)
     ADD_PATTERN(Tile)
+    ADD_PATTERN(Isnan)
     ADD_PATTERN(Share_Data)
     ADD_PATTERN(Swish)
     ADD_PATTERN(Log)
@@ -2230,6 +2255,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
+    ps.Add(std::make_unique<EmbeddingOpPattern>(context));
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
diff --git a/python/paddle/tensorrt/impls/input.py b/python/paddle/tensorrt/impls/input.py
index 8098a9d1264612..5f2796c5e9b6df 100644
--- a/python/paddle/tensorrt/impls/input.py
+++ b/python/paddle/tensorrt/impls/input.py
@@ -60,3 +60,10 @@ def one_hot_converter(network, paddle_op, inputs):
     output_tensor = one_hot_layer.get_output(0)
 
     return [output_tensor]
+
+
+@converter_registry.register("pd_op.embedding", trt_version="8.x")
+def embedding_converter(network, paddle_op, inputs):
+    x_tensor, weight = inputs
+    layer = network.add_gather(weight, x_tensor, 0)
+    return layer.get_output(0)
diff --git a/python/paddle/tensorrt/impls/math.py b/python/paddle/tensorrt/impls/math.py
index 0177a301ec8f86..b098923627b925 100644
--- a/python/paddle/tensorrt/impls/math.py
+++ b/python/paddle/tensorrt/impls/math.py
@@ -25,6 +25,7 @@
     fill_constant_layer,
     get_axes_for_reduce_op,
     trt_cast,
+    trt_equal,
     trt_expand,
     trt_max,
 )
@@ -305,3 +306,13 @@ def maximum_converter(network, paddle_op, inputs):
         network, paddle_op, inputs, trt.ElementWiseOperation.MAX
     )
     return max_layer
+
+
+@converter_registry.register("pd_op.isnan", trt_version="8.x")
+def isnan_converter(network, paddle_op, inputs):
+    input_tensor = inputs[0]
+    equal_tensor = trt_equal(network, input_tensor, input_tensor)
+    layer = network.add_unary(equal_tensor, trt.UnaryOperation.NOT)
+    cast_layer = network.add_identity(layer.get_output(0))
+    cast_layer.set_output_type(0, trt.bool)
+    return cast_layer.get_output(0)
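
Note on the pd_op.isnan converter above: TensorRT 8.x has no native "is NaN" unary op, so the converter leans on the IEEE-754 rule that NaN is the only value that compares unequal to itself. trt_equal(network, x, x) is therefore false exactly at NaN positions, the NOT unary flips that mask, and the trailing identity layer pins the output dtype to bool to match paddle.isnan. A minimal NumPy sketch of the same logic (illustrative reference only, not part of the patch; the helper name isnan_reference is invented here):

    import numpy as np

    def isnan_reference(x):
        # NaN is the only IEEE-754 value for which x != x, so the
        # self-equality mask is False exactly at NaN positions.
        equal_mask = x == x
        return np.logical_not(equal_mask)

    x = np.array([1.0, float("nan"), 3.0], dtype=np.float32)
    print(isnan_reference(x))  # [False  True False]
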
diff --git a/test/tensorrt/test_converter_input.py b/test/tensorrt/test_converter_input.py
index 945ff2133efd1b..50d8ba53572158 100644
--- a/test/tensorrt/test_converter_input.py
+++ b/test/tensorrt/test_converter_input.py
@@ -64,5 +64,49 @@ def test_trt_result(self):
         self.check_trt_result()
 
 
+def embedding_wrapper_func(x):
+    layer = paddle.nn.Embedding(64, 4)
+    return layer(x)
+
+
+class TestEmbeddingCase1TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.nn.functional.embedding
+        self.api_args = {
+            "x": np.array([[3, 16, 24], [6, 4, 47]]).astype("int64"),
+            "weight": np.random.uniform(-1, 1, [64, 4]).astype('float32'),
+        }
+        self.program_config = {"feed_list": ["x", "weight"]}
+        self.dynamic_shape_data = {
+            "x": lambda shape: np.random.randint(1, 64, size=shape).astype(
+                np.int64
+            ),
+        }
+        self.min_shape = {"x": [1, 3]}
+        self.max_shape = {"x": [5, 3]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+
+class TestEmbeddingCase2TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = embedding_wrapper_func
+        self.api_args = {
+            "x": np.array([[3, 16, 24], [6, 4, 47]]).astype("int64"),
+        }
+        self.dynamic_shape_data = {
+            "x": lambda shape: np.random.randint(1, 64, size=shape).astype(
+                np.int64
+            ),
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [1, 3]}
+        self.max_shape = {"x": [5, 3]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/tensorrt/test_converter_math.py b/test/tensorrt/test_converter_math.py
index b6bb62f2f2a66c..327a84c4ee3869 100644
--- a/test/tensorrt/test_converter_math.py
+++ b/test/tensorrt/test_converter_math.py
@@ -450,5 +450,33 @@ def test_trt_result(self):
         self.check_trt_result()
 
 
+class TestIsnanFloatTRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.isnan
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype("float32"),
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [1, 3]}
+        self.max_shape = {"x": [5, 3]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+
+class TestIsnanIntTRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.isnan
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype("int64"),
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [1, 3]}
+        self.max_shape = {"x": [5, 3]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+
 if __name__ == '__main__':
     unittest.main()
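
Note on the pd_op.embedding converter (python/paddle/tensorrt/impls/input.py above): an embedding lookup is row selection from the weight table, so a single gather along axis 0 is enough; that is all network.add_gather(weight, x_tensor, 0) emits. A NumPy sketch of the equivalent computation (illustrative only; the shapes mirror TestEmbeddingCase1TRTPattern):

    import numpy as np

    weight = np.random.uniform(-1, 1, [64, 4]).astype("float32")  # vocab=64, dim=4
    x = np.array([[3, 16, 24], [6, 4, 47]], dtype=np.int64)  # token ids

    # Select rows of the weight table along axis 0; this is the same
    # gather the TRT converter performs on the device.
    out = np.take(weight, x, axis=0)
    print(out.shape)  # (2, 3, 4)
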