Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【SCU】[Paddle TensorRT No.28] Add pd_op.logical_xor converter #69958

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1230,6 +1230,32 @@ class LessThanOpPattern
return true;
}
};

template <typename OpType>
class LogicalCommonOpPattern : public pir::OpRewritePattern<OpType> {
public:
using pir::OpRewritePattern<OpType>::OpRewritePattern;
bool MatchAndRewrite(OpType op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->template attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value x = op.operand_source(0);
pir::Value y = op.operand_source(1);
auto x_dtype = pir::GetDataTypeFromValue(x);
auto y_dtype = pir::GetDataTypeFromValue(y);
if (!(x_dtype.isa<pir::BoolType>() && y_dtype.isa<pir::BoolType>())) {
VLOG(3) << "pd_op.logical_xor op only supports bool datatype";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个可以参考下pd_op.logical_or那个pr的写法,这三个共用一个class,你这里叫做pd_op.logical_xor不合适

return false;
}
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};
using LogicalXorOpPattern =
LogicalCommonOpPattern<paddle::dialect::LogicalXorOp>;

class MulticlassNms3OpPattern
: public pir::OpRewritePattern<paddle::dialect::MulticlassNms3Op> {
public:
Expand Down Expand Up @@ -2259,6 +2285,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ps.Add(std::make_unique<SetValueWithTensor_OpPattern>(context));
ps.Add(std::make_unique<EqualOpPattern>(context));
ps.Add(std::make_unique<NotEqualOpPattern>(context));
ps.Add(std::make_unique<LogicalXorOpPattern>(context));
ps.Add(std::make_unique<TanhOpPattern>(context));
ps.Add(std::make_unique<CeluOpPattern>(context));
ps.Add(std::make_unique<MishOpPattern>(context));
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/tensorrt/impls/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@
"pd_op.greater_than": trt.ElementWiseOperation.GREATER,
"pd_op.less_than": trt.ElementWiseOperation.LESS,
"pd_op.equal": trt.ElementWiseOperation.EQUAL,
"pd_op.logical_xor": trt.ElementWiseOperation.XOR,
}


@converter_registry.register("pd_op.greater_than", trt_version="8.x")
@converter_registry.register("pd_op.less_than", trt_version="8.x")
@converter_registry.register("pd_op.equal", trt_version="8.x")
@converter_registry.register("pd_op.logical_xor", trt_version="8.x")
def logic_converter(network, paddle_op, inputs):
layer_output = add_elementwise_layer(
network, paddle_op, inputs, logic_type_map[paddle_op.name()]
Expand Down
39 changes: 39 additions & 0 deletions test/tensorrt/test_converter_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,5 +140,44 @@ def test_trt_result(self):
self.check_trt_result()


class TestLogicalXorTRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.logical_xor

def test_trt_result(self):
self.api_args = {
"x": np.random.choice([True, False], size=(3,)).astype("bool"),
"y": np.random.choice([True, False], size=(3,)).astype("bool"),
}
self.program_config = {"feed_list": ["x", "y"]}
self.min_shape = {"x": [1], "y": [1]}
self.max_shape = {"x": [5], "y": [5]}
self.check_trt_result()

def test_trt_diff_shape_result(self):
self.api_args = {
"x": np.random.choice([True, False], size=(2, 3)).astype("bool"),
"y": np.random.choice([True, False], size=(3)).astype("bool"),
}
self.program_config = {"feed_list": ["x", "y"]}
self.min_shape = {"x": [1, 3], "y": [3]}
self.max_shape = {"x": [4, 3], "y": [3]}
self.check_trt_result()


class TestLogicalXorMarker(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.logical_xor
self.api_args = {
"x": np.random.randn(3).astype("int64"),
"y": np.random.randn(3).astype("int64"),
}
self.program_config = {"feed_list": ["x", "y"]}
self.target_marker_op = "pd_op.logical_xor"

def test_trt_result(self):
self.check_marker(expected_result=False)


if __name__ == '__main__':
unittest.main()