diff --git a/paddle/fluid/pir/drr/api/tensor_interface.cc b/paddle/fluid/pir/drr/api/tensor_interface.cc
index 1b81b3a5672117..4141357e0a209e 100644
--- a/paddle/fluid/pir/drr/api/tensor_interface.cc
+++ b/paddle/fluid/pir/drr/api/tensor_interface.cc
@@ -30,5 +30,7 @@ bool DtypeInterface::operator==(const DtypeInterface& other) const {
   return *dtype_ == *other.dtype_;
 }
 
+IrDtype DtypeInterface::dtype() const { return *(this->dtype_); }
+
 }  // namespace drr
 }  // namespace pir
diff --git a/paddle/fluid/pir/drr/api/tensor_interface.h b/paddle/fluid/pir/drr/api/tensor_interface.h
index 7629857591bf33..2ae036cb9bc958 100644
--- a/paddle/fluid/pir/drr/api/tensor_interface.h
+++ b/paddle/fluid/pir/drr/api/tensor_interface.h
@@ -42,6 +42,7 @@ class ShapeInterface final {
 class DtypeInterface final {
  public:
   bool operator==(const DtypeInterface& other) const;
+  IrDtype dtype() const;
 
  private:
   explicit DtypeInterface(const IrDtype* dtype) : dtype_(dtype) {}
diff --git a/paddle/fluid/pir/drr/ir_value.h b/paddle/fluid/pir/drr/ir_value.h
index 907df9dfd24ebc..7807ce7f3d9bae 100644
--- a/paddle/fluid/pir/drr/ir_value.h
+++ b/paddle/fluid/pir/drr/ir_value.h
@@ -44,6 +44,11 @@ class IrDtype {
   bool operator==(IrDtype other) const { return dtype_ == other.dtype_; }
 
+  template <typename T>
+  bool isa() const {
+    return dtype_.isa<T>();
+  }
+
  private:
   const pir::Type dtype_;
 };
 
diff --git a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc
index 823c7bdc8f81b7..68619b819bf684 100644
--- a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc
+++ b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h"
+#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
 #include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
 #include "paddle/fluid/platform/place.h"
@@ -22,11 +23,14 @@
 
 namespace {
 
-inline int getSMVersion() {
+int getSMVersion() {
   int sm_version = 80;
 #if defined(PADDLE_WITH_CUDA)
   sm_version = paddle::platform::GetGPUComputeCapability(
       paddle::platform::GetCurrentDeviceId());
+#else
+  PADDLE_THROW(paddle::platform::errors::Unavailable(
+      "fused_weight_only_linear_pass needs paddle compiled with CUDA."));
 #endif
   return sm_version;
 }
@@ -40,12 +44,14 @@ class FusedWeightOnlyLinearPattern
     //
     pir::drr::SourcePattern src = ctx->SourcePattern();
     const auto &matmul =
-        src.Op("pd_op.matmul",
+        src.Op(paddle::dialect::MatmulOp::name(),
               {{"transpose_x", src.Attr("matmul_transpose_x")},
                {"transpose_y", src.Attr("matmul_transpose_y")}});
+    const auto &parameter = src.Op(
+        pir::ParameterOp::name(), {{"parameter_name", src.Attr("param_name")}});
+    src.Tensor("w") = parameter();
     src.Tensor("matmul_out") = matmul(src.Tensor("x"), src.Tensor("w"));
-
-    const auto &add = src.Op("pd_op.add");
+    const auto &add = src.Op(paddle::dialect::AddOp::name());
     src.Tensor("add_out") = add(src.Tensor("matmul_out"), src.Tensor("bias"));
 
     //
@@ -62,6 +68,17 @@ class FusedWeightOnlyLinearPattern
             return false;
           }
 
+          auto w_dims = match_ctx.Tensor("w").Shape();
+          if (w_dims.at(0) % 64 != 0 || w_dims.at(1) % 16 != 0) return false;
+
+          auto w_dtype = match_ctx.Tensor("w").Dtype();
+          if (!w_dtype.dtype().isa<pir::Float16Type>() &&
+              !w_dtype.dtype().isa<pir::BFloat16Type>())
+            return false;
+
+          auto x_dims = match_ctx.Tensor("x").Shape();
+          if (x_dims.at(x_dims.size() - 1) != w_dims.at(1)) return false;
+
           return true;
         });
     //
@@ -74,15 +91,15 @@ class FusedWeightOnlyLinearPattern
         res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
           return "weight_only_int8";
         });
-    // int arch = getSMVersion();
-    const auto &weight_quantize_arch_attr =
-        res.Attr([&](const pir::drr::MatchContext &match_ctx) -> std::any {
-          return 80;
+
+    const auto &arch_attr =
+        res.Attr([&](const pir::drr::MatchContext &match_ctx) -> int {
+          return getSMVersion();
         });
-    const auto &weight_quantize = res.Op(
-        "pd_op.weight_quantize",
-        {{"algo", weight_only_int8_attr}, {"arch", weight_quantize_arch_attr}});
+    const auto &weight_quantize =
+        res.Op(paddle::dialect::WeightQuantizeOp::name(),
+               {{"algo", weight_only_int8_attr}, {"arch", arch_attr}});
     weight_quantize({&res.Tensor("w")},
                     {&res.Tensor("quanted_weight_tensor"),
                      &res.Tensor("weight_scale_tensor")});
 
@@ -92,12 +109,9 @@ class FusedWeightOnlyLinearPattern
           return "int8";
         });
 
-    const auto &weight_only_linear_arch_attr = res.Attr(
-        [&](const pir::drr::MatchContext &match_ctx) -> int { return 80; });
     const auto &weight_only_linear =
-        res.Op("pd_op.weight_only_linear",
-               {{"weight_dtype", weight_dtype_attr},
-                {"arch", weight_only_linear_arch_attr}});
+        res.Op(paddle::dialect::WeightOnlyLinearOp::name(),
+               {{"weight_dtype", weight_dtype_attr}, {"arch", arch_attr}});
     weight_only_linear({&res.Tensor("x"),
                         &res.Tensor("quanted_weight_tensor"),
                         &res.Tensor("bias"),
@@ -119,8 +133,8 @@ class FusedWeightOnlyLinearPass : public pir::PatternRewritePass {
 
   bool CanApplyOn(pir::Operation *op) const override {
     int sm_vesion = getSMVersion();
-    if (sm_vesion != 70 && sm_vesion != 80 && sm_vesion != 86 &&
-        sm_vesion != 75) {
+    if (sm_vesion != 70 && sm_vesion != 75 && sm_vesion != 80 &&
+        sm_vesion != 86) {
       return false;
     }
     return op->num_regions() > 0;
diff --git a/test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py b/test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py
index 1ccfe61a9c13bb..5cd0ff6898039b 100644
--- a/test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py
+++ b/test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py
@@ -19,6 +19,7 @@
 
 import paddle
 from paddle.base import core
+from paddle.pir.core import create_parameter
 
 np.random.seed(2013)
 
@@ -56,16 +57,18 @@ def build_ir_progam(self):
                 x = paddle.static.data(
                     name='x', shape=[3, 64, 64], dtype=self.dtype
                 )
-                w = paddle.static.data(
-                    name="w", shape=[64, 64], dtype=self.dtype
+                initializer = paddle.nn.initializer.Constant(0.0)
+                w = create_parameter(
+                    shape=[64, 64], dtype=self.dtype, initializer=initializer
                 )
                 bias_ = paddle.static.data(
-                    name="bias", shape=[64], dtype=self.dtype
+                    name="bias",
+                    shape=[64],
+                    dtype=self.dtype,
                 )
                 bias = paddle.assign(bias_)
                 res1 = paddle.matmul(x=x, y=w)
                 out = paddle.add(res1, bias)
-
                 self.pass_list = ['fused_weight_only_linear_pass']
                 self.feeds = {
                     "x": np.random.random((3, 64, 64)).astype(self.dtype),
@@ -73,17 +76,19 @@ def build_ir_progam(self):
                     "bias": np.random.random(64).astype(self.dtype),
                 }
                 self.fetch_list = [out]
 
-                self.valid_op_map = {
-                    "pd_op.weight_only_linear": 1,
-                    "pd_op.weight_quantize": 1,
-                    "pd_op.matmul": 0,
-                    "pd_op.add": 0,
-                }
+        return pir_program
 
     def setUp(self):
         self.place_runtime = "gpu"
         self.dtype = 'float32'
+        # weight_quantize needs the weight's dtype to be fp16 or bf16
+        self.valid_op_map = {
+            "pd_op.weight_only_linear": 0,
+            "pd_op.weight_quantize": 0,
+            "pd_op.matmul": 1,
+            "pd_op.add": 1,
+        }
 
     def sample_program(self):
         yield self.build_ir_progam(), False
 
@@ -96,6 +101,56 @@ class TestFusedWeightOnlyLinearPass_Fp16(TestFusedWeightOnlyLinearPass_Fp32):
     def setUp(self):
         self.place_runtime = "gpu"
         self.dtype = 'float16'
+        self.valid_op_map = {
+            "pd_op.weight_only_linear": 1,
+            "pd_op.weight_quantize": 1,
+            "pd_op.matmul": 0,
+            "pd_op.add": 0,
+        }
+
+
+class TestFusedWeightOnlyLinearPass_wdim_divisible_by_16(
+    TestFusedWeightOnlyLinearPass_Fp32
+):
+    def build_ir_progam(self):
+        pir_program = None
+        with paddle.pir_utils.IrGuard():
+            pir_program = paddle.static.Program()
+            with paddle.pir.core.program_guard(pir_program):
+                x = paddle.static.data(
+                    name='x', shape=[3, 64, 64], dtype=self.dtype
+                )
+                initializer = paddle.nn.initializer.Constant(0.0)
+                w = create_parameter(
+                    shape=[64, 15], dtype=self.dtype, initializer=initializer
+                )
+                bias_ = paddle.static.data(
+                    name="bias",
+                    shape=[15],
+                    dtype=self.dtype,
+                )
+                bias = paddle.assign(bias_)
+                res1 = paddle.matmul(x=x, y=w)
+                out = paddle.add(res1, bias)
+                self.pass_list = ['fused_weight_only_linear_pass']
+                self.feeds = {
+                    "x": np.random.random((3, 64, 64)).astype(self.dtype),
+                    "w": np.random.random((64, 15)).astype(self.dtype),
+                    "bias": np.random.random(15).astype(self.dtype),
+                }
+                self.fetch_list = [out]
+
+        return pir_program
+
+    def setUp(self):
+        self.place_runtime = "gpu"
+        self.dtype = 'float16'
+        self.valid_op_map = {
+            "pd_op.weight_only_linear": 0,
+            "pd_op.weight_quantize": 0,
+            "pd_op.matmul": 1,
+            "pd_op.add": 1,
+        }
 
 
 if __name__ == "__main__":
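The sketch below is an illustrative, standalone Python summary of the gating logic that the new constraint lambda adds to FusedWeightOnlyLinearPattern; it is not part of the patch. The helper name matches_weight_only_pattern, the plain-list shapes, and the string dtypes are hypothetical stand-ins for the pir::Float16Type / pir::BFloat16Type checks and the DRR match context used in the C++ code above.

def matches_weight_only_pattern(x_shape, w_shape, w_dtype,
                                transpose_x=False, transpose_y=False):
    """Mirror of the constraint checks added by this PR (assumptions noted above)."""
    # The pattern only fires for a plain matmul (no transposes).
    if transpose_x or transpose_y:
        return False
    # Shape gates added by this PR: w_dims[0] % 64 == 0 and w_dims[1] % 16 == 0.
    if w_shape[0] % 64 != 0 or w_shape[1] % 16 != 0:
        return False
    # weight_quantize only handles fp16/bf16 weights.
    if w_dtype not in ("float16", "bfloat16"):
        return False
    # x's last dim must equal w_dims[1], exactly as checked in the C++ lambda.
    if x_shape[-1] != w_shape[1]:
        return False
    return True

# Shapes used by the tests: the [64, 64] fp16 weight fuses, while the
# [64, 15] weight and the fp32 weight leave matmul/add unfused.
assert matches_weight_only_pattern([3, 64, 64], [64, 64], "float16")
assert not matches_weight_only_pattern([3, 64, 64], [64, 15], "float16")
assert not matches_weight_only_pattern([3, 64, 64], [64, 64], "float32")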