From 9bae5d3eae54a29b12d3ddfe597e46a03795661a Mon Sep 17 00:00:00 2001 From: SSuren4ik <114421510+SSuren4ik@users.noreply.github.com> Date: Tue, 21 May 2024 18:56:16 +0300 Subject: [PATCH] Suren Simonyan Lab4 Var2 (#175) FMA pass on MLIR --- .../lab4/simonyan_suren/CMakeLists.txt | 12 ++++ .../simonyan_suren/SimonyanSurenFMAPass.cpp | 61 ++++++++++++++++++ .../lab4/simonyan_suren/test_simonyan.mlir | 63 +++++++++++++++++++ 3 files changed, 136 insertions(+) create mode 100644 mlir/lib/Transforms/lab4/simonyan_suren/CMakeLists.txt create mode 100644 mlir/lib/Transforms/lab4/simonyan_suren/SimonyanSurenFMAPass.cpp create mode 100644 mlir/test/Transforms/lab4/simonyan_suren/test_simonyan.mlir diff --git a/mlir/lib/Transforms/lab4/simonyan_suren/CMakeLists.txt b/mlir/lib/Transforms/lab4/simonyan_suren/CMakeLists.txt new file mode 100644 index 0000000000000..5c16808b55718 --- /dev/null +++ b/mlir/lib/Transforms/lab4/simonyan_suren/CMakeLists.txt @@ -0,0 +1,12 @@ +set(PluginName SimonyanSurenFMAPass) + +file(GLOB_RECURSE ALL_SOURCE_FILES *.cpp *.h) +add_llvm_pass_plugin(${PluginName} + ${ALL_SOURCE_FILES} + DEPENDS + intrinsics_gen + MLIRBuiltinLocationAttributesIncGen + BUILDTREE_ONLY + ) + +set(MLIR_TEST_DEPENDS ${PluginName} ${MLIR_TEST_DEPENDS} PARENT_SCOPE) diff --git a/mlir/lib/Transforms/lab4/simonyan_suren/SimonyanSurenFMAPass.cpp b/mlir/lib/Transforms/lab4/simonyan_suren/SimonyanSurenFMAPass.cpp new file mode 100644 index 0000000000000..028f67b92322f --- /dev/null +++ b/mlir/lib/Transforms/lab4/simonyan_suren/SimonyanSurenFMAPass.cpp @@ -0,0 +1,61 @@ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Tools/Plugins/PassPlugin.h" + +using namespace mlir; + +namespace { +class SimonyanSurenFMAPass + : public PassWrapper<SimonyanSurenFMAPass, OperationPass<ModuleOp>> { +private: + void handleMulOp(LLVM::FAddOp &addOp, LLVM::FMulOp &mulOp, + Value &otherOperand) { + OpBuilder builder(addOp); + Value fma = builder.create<LLVM::FMAOp>(addOp.getLoc(), 
mulOp.getOperand(0), + mulOp.getOperand(1), otherOperand); + addOp.replaceAllUsesWith(fma); + addOp.erase(); + } + +public: + void runOnOperation() override { + ModuleOp module = getOperation(); + // Fuse each fadd that consumes an fmul result into llvm.intr.fma. + module.walk([this](LLVM::FAddOp addOp) { + Value addLHS = addOp.getOperand(0); + Value addRHS = addOp.getOperand(1); + if (auto mulOpLHS = addLHS.getDefiningOp<LLVM::FMulOp>()) { + handleMulOp(addOp, mulOpLHS, addRHS); + } else if (auto mulOpRHS = addRHS.getDefiningOp<LLVM::FMulOp>()) { + handleMulOp(addOp, mulOpRHS, addLHS); + } + }); + + // Erase multiplies left without users after the fusion above. + module.walk([](LLVM::FMulOp mulOp) { + if (mulOp.use_empty()) { + mulOp.erase(); + } + }); + } + + StringRef getArgument() const final { return "simonyan_suren_fma"; } + StringRef getDescription() const final { + return "Replaces add and multiply operations with a single instruction."; + } +}; + +} // namespace + +MLIR_DECLARE_EXPLICIT_TYPE_ID(SimonyanSurenFMAPass) +MLIR_DEFINE_EXPLICIT_TYPE_ID(SimonyanSurenFMAPass) + +PassPluginLibraryInfo getSimonyanSurenFMAPassPluginInfo() { + return {MLIR_PLUGIN_API_VERSION, "simonyan_suren_fma", LLVM_VERSION_STRING, + []() { PassRegistration<SimonyanSurenFMAPass>(); }}; +} + +extern "C" LLVM_ATTRIBUTE_WEAK PassPluginLibraryInfo mlirGetPassPluginInfo() { + return getSimonyanSurenFMAPassPluginInfo(); +} \ No newline at end of file diff --git a/mlir/test/Transforms/lab4/simonyan_suren/test_simonyan.mlir b/mlir/test/Transforms/lab4/simonyan_suren/test_simonyan.mlir new file mode 100644 index 0000000000000..3a9aa57bf1dc1 --- /dev/null +++ b/mlir/test/Transforms/lab4/simonyan_suren/test_simonyan.mlir @@ -0,0 +1,63 @@ +// RUN: mlir-opt -load-pass-plugin=%mlir_lib_dir/SimonyanSurenFMAPass%shlibext --pass-pipeline="builtin.module(simonyan_suren_fma)" %s | FileCheck %s + +// double func1(double a, double b) { + // double constant1 = 2.0; + // double constant2 = 5.0; + // double result = a * constant1 + b * constant2; + // return result; +// } + +// double func2(double a, double b) { + // double constant1 = 5.0; + // 
double constant2 = 2.0; + // double mul = a * b; + // double add1 = mul + constant1; + // double mul2 = mul * constant2; + // double result = mul2 + add1; + // return result; +// } + +// double func3(double a, double b) { + // double constant1 = 2.0; + // double constant2 = -7.0; + // double mul = a * constant1; + // double result = mul + constant2; + // return result; +// } + +module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<4xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry, dense<32> : vector<4xi32>>, #dlti.dl_entry, dense<64> : vector<4xi32>>, #dlti.dl_entry, dense<32> : vector<4xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry<"dlti.stack_alignment", 128 : i32>, #dlti.dl_entry<"dlti.endianness", "little">>} { + llvm.func local_unnamed_addr @_Z5func1dd(%arg0: f64 {llvm.noundef}, %arg1: f64 {llvm.noundef}) -> (f64 {llvm.noundef}) attributes {memory = #llvm.memory_effects, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", "nounwind", "willreturn", ["uwtable", "2"], ["min-legal-vector-width", "0"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "x86-64"], ["target-features", "+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87"], ["tune-cpu", "generic"]]} { + %0 = llvm.mlir.constant(2.000000e+00 : f64) : f64 + %1 = llvm.mlir.constant(5.000000e+00 : f64) : f64 + %2 = llvm.fmul %arg0, %0 : f64 + %3 = llvm.fmul %arg1, %1 : f64 + %4 = llvm.fadd %2, %3 : f64 + llvm.return %4 : f64 + // CHECK: %2 = llvm.fmul %arg1, %1 : f64 + // CHECK-NEXT: %3 = llvm.intr.fma(%arg0, %0, %2) : (f64, f64, f64) -> f64 + // CHECK-NEXT: llvm.return %3 : f64 + } + + llvm.func local_unnamed_addr @_Z5func2dd(%arg0: f64 {llvm.noundef}, %arg1: f64 {llvm.noundef}) -> (f64 {llvm.noundef}) attributes 
{memory = #llvm.memory_effects, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", "nounwind", "willreturn", ["uwtable", "2"], ["min-legal-vector-width", "0"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "x86-64"], ["target-features", "+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87"], ["tune-cpu", "generic"]]} { + %0 = llvm.mlir.constant(5.000000e+00 : f64) : f64 + %1 = llvm.mlir.constant(2.000000e+00 : f64) : f64 + %2 = llvm.fmul %arg0, %arg1 : f64 + %3 = llvm.fadd %2, %0 : f64 + %4 = llvm.fmul %2, %1 : f64 + %5 = llvm.fadd %4, %3 : f64 + llvm.return %5 : f64 + // CHECK: %3 = llvm.intr.fma(%arg0, %arg1, %0) : (f64, f64, f64) -> f64 + // CHECK-NEXT: %4 = llvm.intr.fma(%2, %1, %3) : (f64, f64, f64) -> f64 + // CHECK-NEXT: llvm.return %4 : f64 + } + + llvm.func local_unnamed_addr @_Z5func3dd(%arg0: f64 {llvm.noundef}, %arg1: f64 {llvm.noundef}) -> (f64 {llvm.noundef}) attributes {memory = #llvm.memory_effects, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", "nounwind", "willreturn", ["uwtable", "2"], ["min-legal-vector-width", "0"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "x86-64"], ["target-features", "+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87"], ["tune-cpu", "generic"]]} { + %0 = llvm.mlir.constant(2.000000e+00 : f64) : f64 + %1 = llvm.mlir.constant(-7.000000e+00 : f64) : f64 + %2 = llvm.fmul %arg0, %0 : f64 + %3 = llvm.fadd %2, %1 : f64 + llvm.return %3 : f64 + // CHECK: %2 = llvm.intr.fma(%arg0, %0, %1) : (f64, f64, f64) -> f64 + // CHECK-NEXT: llvm.return %2 : f64 + } +} \ No newline at end of file