diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index 08e0d058bc0090..a10ea3dbc1d4c5 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -1421,12 +1421,13 @@ class EmbeddingOpPattern op.attribute(kCanRunTrtAttr).data()) { return false; } - if (pir::GetDefiningOpForInput(op, 1)->name() == "builtin.parameter") { - // trt.Weights don't have the shape info. - VLOG(3) << "Skip to convert into TRT while found weight is a parameter " - "in pd_op.embedding."; - return false; - } + // if (pir::GetDefiningOpForInput(op, 1)->name() == "builtin.parameter") { + // // trt.Weights don't have the shape info. + // VLOG(3) << "Skip to convert into TRT while found weight is a parameter + // " + // "in pd_op.embedding."; + // return false; + // } op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true)); return true; } diff --git a/test/tensorrt/test_converter_input.py b/test/tensorrt/test_converter_input.py index 5a1e8a617c301d..50d8ba53572158 100644 --- a/test/tensorrt/test_converter_input.py +++ b/test/tensorrt/test_converter_input.py @@ -101,9 +101,11 @@ def setUp(self): ), } self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [1, 3]} + self.max_shape = {"x": [5, 3]} def test_trt_result(self): - self.check_marker(expected_result=False) + self.check_trt_result() if __name__ == '__main__':