[PIR & Inference] Refactor fused_weight_only_linear_pass (#59792)
* refactor: refactor fused_weight_only_linear_pass

* refactor: add else case for PADDLE_WITH_CUDA

* fix: fix typo

* refactor: refactor pass

* refactor: support sm 70, 75, 80 and 86 in pass

* refactor: refactor pass and test

* fix: fix typo

* refactor: use xxOp::name() instead of pd_op.xx in pass

* refactor: refactor error msg and fix typo

* refactor: refactor pass and test

* fix: fix typo

* refactor: refactor IrDtype
Wanglongzhi2001 authored Dec 12, 2023
1 parent d9407a5 commit 2360fae
Showing 5 changed files with 105 additions and 28 deletions.
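For context, here is a minimal Python sketch of the computation this pass rewrites. It assumes a CUDA build on a supported GPU, and assumes that paddle.nn.quant.weight_quantize and paddle.nn.quant.weight_only_linear are the Python-level counterparts of the pd_op.weight_quantize and pd_op.weight_only_linear ops the fused pattern emits:

import paddle
from paddle.nn.quant import weight_only_linear, weight_quantize

# Shapes mirror the test below; the pass only fires for fp16/bf16 weights.
x = paddle.randn([3, 64, 64], dtype='float16')
w = paddle.randn([64, 64], dtype='float16')
bias = paddle.randn([64], dtype='float16')

# Before the pass: a plain matmul followed by add.
out_ref = paddle.add(paddle.matmul(x, w), bias)

# After the pass (conceptually): quantize the weight to int8 once, then
# run the fused weight-only-linear kernel on the original activations.
qw, scale = weight_quantize(w, algo="weight_only_int8")
out_fused = weight_only_linear(
    x, qw, bias=bias, weight_scale=scale, weight_dtype="int8"
)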
2 changes: 2 additions & 0 deletions paddle/fluid/pir/drr/api/tensor_interface.cc
@@ -30,5 +30,7 @@ bool DtypeInterface::operator==(const DtypeInterface& other) const {
   return *dtype_ == *other.dtype_;
 }
 
+IrDtype DtypeInterface::dtype() const { return *(this->dtype_); }
+
 }  // namespace drr
 }  // namespace pir
1 change: 1 addition & 0 deletions paddle/fluid/pir/drr/api/tensor_interface.h
@@ -42,6 +42,7 @@ class ShapeInterface final {
 class DtypeInterface final {
  public:
   bool operator==(const DtypeInterface& other) const;
+  IrDtype dtype() const;
 
  private:
   explicit DtypeInterface(const IrDtype* dtype) : dtype_(dtype) {}
5 changes: 5 additions & 0 deletions paddle/fluid/pir/drr/ir_value.h
@@ -44,6 +44,11 @@ class IrDtype {

   bool operator==(IrDtype other) const { return dtype_ == other.dtype_; }
 
+  template <typename T>
+  bool isa() const {
+    return dtype_.isa<T>();
+  }
+
  private:
   const pir::Type dtype_;
 };
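Taken together, the three small additions above are plumbing for the pass below: DtypeInterface::dtype() exposes the wrapped IrDtype, and the new IrDtype::isa<T>() forwards to pir::Type::isa<T>(). This is what lets the constraint code in the pass check a weight's element type with match_ctx.Tensor("w").Dtype().dtype().isa<pir::Float16Type>().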
50 changes: 32 additions & 18 deletions paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h"
+#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
 #include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
 #include "paddle/fluid/platform/place.h"
@@ -22,11 +23,14 @@

 namespace {
 
-inline int getSMVersion() {
+int getSMVersion() {
   int sm_version = 80;
 #if defined(PADDLE_WITH_CUDA)
   sm_version = paddle::platform::GetGPUComputeCapability(
       paddle::platform::GetCurrentDeviceId());
+#else
+  PADDLE_THROW(paddle::platform::errors::Unavailable(
+      "fused_weight_only_linear_pass needs paddle compiled with CUDA."));
 #endif
   return sm_version;
 }
@@ -40,12 +44,14 @@ class FusedWeightOnlyLinearPattern
     //
     pir::drr::SourcePattern src = ctx->SourcePattern();
     const auto &matmul =
-        src.Op("pd_op.matmul",
+        src.Op(paddle::dialect::MatmulOp::name(),
                {{"transpose_x", src.Attr("matmul_transpose_x")},
                 {"transpose_y", src.Attr("matmul_transpose_y")}});
+    const auto &parameter = src.Op(
+        pir::ParameterOp::name(), {{"parameter_name", src.Attr("param_name")}});
+    src.Tensor("w") = parameter();
     src.Tensor("matmul_out") = matmul(src.Tensor("x"), src.Tensor("w"));
 
-    const auto &add = src.Op("pd_op.add");
+    const auto &add = src.Op(paddle::dialect::AddOp::name());
     src.Tensor("add_out") = add(src.Tensor("matmul_out"), src.Tensor("bias"));
 
     //
@@ -62,6 +68,17 @@
         return false;
       }
 
+      auto w_dims = match_ctx.Tensor("w").Shape();
+      if (w_dims.at(0) % 64 != 0 || w_dims.at(1) % 16 != 0) return false;
+
+      auto w_dtype = match_ctx.Tensor("w").Dtype();
+      if (!w_dtype.dtype().isa<pir::Float16Type>() &&
+          !w_dtype.dtype().isa<pir::BFloat16Type>())
+        return false;
+
+      auto x_dims = match_ctx.Tensor("x").Shape();
+      if (x_dims.at(x_dims.size() - 1) != w_dims.at(1)) return false;
+
       return true;
     });
     //
@@ -74,15 +91,15 @@
         res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
           return "weight_only_int8";
         });
-    // int arch = getSMVersion();
-    const auto &weight_quantize_arch_attr =
-        res.Attr([&](const pir::drr::MatchContext &match_ctx) -> std::any {
-          return 80;
+
+    const auto &arch_attr =
+        res.Attr([&](const pir::drr::MatchContext &match_ctx) -> int {
+          return getSMVersion();
         });
 
-    const auto &weight_quantize = res.Op(
-        "pd_op.weight_quantize",
-        {{"algo", weight_only_int8_attr}, {"arch", weight_quantize_arch_attr}});
+    const auto &weight_quantize =
+        res.Op(paddle::dialect::WeightQuantizeOp::name(),
+               {{"algo", weight_only_int8_attr}, {"arch", arch_attr}});
     weight_quantize({&res.Tensor("w")},
                     {&res.Tensor("quanted_weight_tensor"),
                      &res.Tensor("weight_scale_tensor")});
@@ -92,12 +109,9 @@
return "int8";
});

const auto &weight_only_linear_arch_attr = res.Attr(
[&](const pir::drr::MatchContext &match_ctx) -> int { return 80; });
const auto &weight_only_linear =
res.Op("pd_op.weight_only_linear",
{{"weight_dtype", weight_dtype_attr},
{"arch", weight_only_linear_arch_attr}});
res.Op(paddle::dialect::WeightOnlyLinearOp::name(),
{{"weight_dtype", weight_dtype_attr}, {"arch", arch_attr}});
weight_only_linear({&res.Tensor("x"),
&res.Tensor("quanted_weight_tensor"),
&res.Tensor("bias"),
@@ -119,8 +133,8 @@ class FusedWeightOnlyLinearPass : public pir::PatternRewritePass {

   bool CanApplyOn(pir::Operation *op) const override {
     int sm_vesion = getSMVersion();
-    if (sm_vesion != 70 && sm_vesion != 80 && sm_vesion != 86 &&
-        sm_vesion != 75) {
+    if (sm_vesion != 70 && sm_vesion != 75 && sm_vesion != 80 &&
+        sm_vesion != 86) {
       return false;
     }
     return op->num_regions() > 0;
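For reference on the arch whitelist in CanApplyOn above: SM 70 is Volta (e.g. V100), SM 75 is Turing (e.g. T4), and SM 80/86 are Ampere (e.g. A100 and RTX 30xx). A minimal sketch, assuming a CUDA build of Paddle, of reading the same compute capability from Python:

import paddle

# get_device_capability() returns (major, minor), e.g. (8, 0) on an A100.
major, minor = paddle.device.cuda.get_device_capability()
sm = major * 10 + minor  # same value the pass compares against
print(f"SM {sm}; pass can apply: {sm in (70, 75, 80, 86)}")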
75 changes: 65 additions & 10 deletions test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py
@@ -19,6 +19,7 @@

 import paddle
 from paddle.base import core
+from paddle.pir.core import create_parameter
 
 np.random.seed(2013)

@@ -56,34 +57,38 @@ def build_ir_progam(self):
                 x = paddle.static.data(
                     name='x', shape=[3, 64, 64], dtype=self.dtype
                 )
-                w = paddle.static.data(
-                    name="w", shape=[64, 64], dtype=self.dtype
+                initializer = paddle.nn.initializer.Constant(0.0)
+                w = create_parameter(
+                    shape=[64, 64], dtype=self.dtype, initializer=initializer
                 )
                 bias_ = paddle.static.data(
-                    name="bias", shape=[64], dtype=self.dtype
+                    name="bias",
+                    shape=[64],
+                    dtype=self.dtype,
                 )
                 bias = paddle.assign(bias_)
                 res1 = paddle.matmul(x=x, y=w)
                 out = paddle.add(res1, bias)
 
                 self.pass_list = ['fused_weight_only_linear_pass']
                 self.feeds = {
                     "x": np.random.random((3, 64, 64)).astype(self.dtype),
-                    "w": np.random.random((64, 64)).astype(self.dtype),
                     "bias": np.random.random(64).astype(self.dtype),
                 }
                 self.fetch_list = [out]
-                self.valid_op_map = {
-                    "pd_op.weight_only_linear": 1,
-                    "pd_op.weight_quantize": 1,
-                    "pd_op.matmul": 0,
-                    "pd_op.add": 0,
-                }
 
         return pir_program
 
     def setUp(self):
         self.place_runtime = "gpu"
         self.dtype = 'float32'
+        # weight_quantize need weight's dtype to be fp16 or bf16
+        self.valid_op_map = {
+            "pd_op.weight_only_linear": 0,
+            "pd_op.weight_quantize": 0,
+            "pd_op.matmul": 1,
+            "pd_op.add": 1,
+        }
 
     def sample_program(self):
         yield self.build_ir_progam(), False
@@ -96,6 +101,56 @@ class TestFusedWeightOnlyLinearPass_Fp16(TestFusedWeightOnlyLinearPass_Fp32):
     def setUp(self):
         self.place_runtime = "gpu"
         self.dtype = 'float16'
+        self.valid_op_map = {
+            "pd_op.weight_only_linear": 1,
+            "pd_op.weight_quantize": 1,
+            "pd_op.matmul": 0,
+            "pd_op.add": 0,
+        }
+
+
+class TestFusedWeightOnlyLinearPass_wdim_divisible_by_16(
+    TestFusedWeightOnlyLinearPass_Fp32
+):
+    def build_ir_progam(self):
+        pir_program = None
+        with paddle.pir_utils.IrGuard():
+            pir_program = paddle.static.Program()
+            with paddle.pir.core.program_guard(pir_program):
+                x = paddle.static.data(
+                    name='x', shape=[3, 64, 64], dtype=self.dtype
+                )
+                initializer = paddle.nn.initializer.Constant(0.0)
+                w = create_parameter(
+                    shape=[64, 15], dtype=self.dtype, initializer=initializer
+                )
+                bias_ = paddle.static.data(
+                    name="bias",
+                    shape=[15],
+                    dtype=self.dtype,
+                )
+                bias = paddle.assign(bias_)
+                res1 = paddle.matmul(x=x, y=w)
+                out = paddle.add(res1, bias)
+                self.pass_list = ['fused_weight_only_linear_pass']
+                self.feeds = {
+                    "x": np.random.random((3, 64, 64)).astype(self.dtype),
+                    "w": np.random.random((64, 15)).astype(self.dtype),
+                    "bias": np.random.random(15).astype(self.dtype),
+                }
+                self.fetch_list = [out]
+
+        return pir_program
+
+    def setUp(self):
+        self.place_runtime = "gpu"
+        self.dtype = 'float16'
+        self.valid_op_map = {
+            "pd_op.weight_only_linear": 0,
+            "pd_op.weight_quantize": 0,
+            "pd_op.matmul": 1,
+            "pd_op.add": 1,
+        }
 
 
 if __name__ == "__main__":
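Note on the new negative test: its weight parameter has shape [64, 15], and 15 % 16 != 0 fails the new shape constraint in the pass, so its valid_op_map asserts that matmul and add survive and that no weight_quantize or weight_only_linear ops are created.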
