Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ooooo-create committed Dec 23, 2024
1 parent a596e3f commit 258e8b4
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
13 changes: 7 additions & 6 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1421,12 +1421,13 @@ class EmbeddingOpPattern
op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
if (pir::GetDefiningOpForInput(op, 1)->name() == "builtin.parameter") {
// trt.Weights don't have the shape info.
VLOG(3) << "Skip to convert into TRT while found weight is a parameter "
"in pd_op.embedding.";
return false;
}
// if (pir::GetDefiningOpForInput(op, 1)->name() == "builtin.parameter") {
// // trt.Weights don't have the shape info.
// VLOG(3) << "Skip to convert into TRT while found weight is a parameter
// "
// "in pd_op.embedding.";
// return false;
// }
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
Expand Down
4 changes: 3 additions & 1 deletion test/tensorrt/test_converter_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,11 @@ def setUp(self):
),
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [1, 3]}
self.max_shape = {"x": [5, 3]}

def test_trt_result(self):
self.check_marker(expected_result=False)
self.check_trt_result()


if __name__ == '__main__':
Expand Down

0 comments on commit 258e8b4

Please sign in to comment.