Skip to content
This repository has been archived by the owner on Jan 19, 2025. It is now read-only.

Commit

Permalink
Alexseev Danila. Lab 4. Opt 2. (#162)
Browse files Browse the repository at this point in the history
* Implemented fma pass which merge add and mul

* fix

* fix test

* lil fix

* Update mlir/test/Transforms/lab4/alexseev_danila/test.mlir

---------

Co-authored-by: Alexseev Danila <user@DESKTOP-9FK8SDU.localdomain>
  • Loading branch information
lxvdnl and Alexseev Danila authored May 19, 2024
1 parent 9042d43 commit e8113fa
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Tools/Plugins/PassPlugin.h"

/// Module pass that fuses an `llvm.fmul` feeding an `llvm.fadd` into a single
/// `llvm.intr.fma` operation (mlir::LLVM::FMAOp).
///
/// For every `llvm.fadd` in the module, the left operand is inspected first
/// and then the right one; the first operand that is produced by an
/// `llvm.fmul` is fused. The multiplication itself is removed only when the
/// fusion leaves it without users.
///
/// NOTE(review): the rewrite is applied unconditionally. Fast-math flags on
/// the original operations are neither checked nor forwarded to the new op.
class AlexseevMulAddMergePass
    : public mlir::PassWrapper<AlexseevMulAddMergePass,
                               mlir::OperationPass<mlir::ModuleOp>> {
  /// Replaces `addOp` with `fma(mulOp.lhs, mulOp.rhs, otherOperand)`.
  ///
  /// \param addOp        the addition being replaced; erased on return.
  /// \param mulOp        the multiplication producing one operand of `addOp`;
  ///                     erased only if it has no remaining uses.
  /// \param otherOperand the operand of `addOp` that is not the product.
  ///
  /// Op and value handles are cheap to copy, so they are taken by value.
  void createFMAOperation(mlir::LLVM::FAddOp addOp, mlir::LLVM::FMulOp mulOp,
                          mlir::Value otherOperand) {
    // Insert the new op right before the addition so that every operand
    // still dominates it.
    mlir::OpBuilder builder(addOp);
    mlir::Value fma = builder.create<mlir::LLVM::FMAOp>(
        addOp.getLoc(), addOp.getType(), mulOp.getOperand(0),
        mulOp.getOperand(1), otherOperand);
    addOp.replaceAllUsesWith(fma);
    addOp.erase();
    // The product may feed other operations (e.g. a second addition); keep it
    // alive in that case.
    if (mulOp.use_empty())
      mulOp.erase();
  }

public:
  /// Name used to select the pass in a pass pipeline.
  mlir::StringRef getArgument() const final { return "alexseev_mul_add_merge"; }

  /// One-line summary of the pass.
  mlir::StringRef getDescription() const final {
    return "Merge multiplication and addition into a single llvm.intr.fma";
  }

  /// Walks every `llvm.fadd` nested under the module and fuses it with a
  /// multiplication operand when one exists.
  void runOnOperation() override {
    // NOTE(review): the visited op is erased from inside the callback; this
    // relies on the default walk order tolerating erasure of the current op.
    getOperation().walk([&](mlir::LLVM::FAddOp addOp) {
      mlir::Value addLHS = addOp.getOperand(0);
      mlir::Value addRHS = addOp.getOperand(1);

      // The left operand takes precedence when both sides are products.
      if (auto mulOpLHS = addLHS.getDefiningOp<mlir::LLVM::FMulOp>()) {
        createFMAOperation(addOp, mulOpLHS, addRHS);
      } else if (auto mulOpRHS = addRHS.getDefiningOp<mlir::LLVM::FMulOp>()) {
        createFMAOperation(addOp, mulOpRHS, addLHS);
      }
    });
  }
};

// Declare and define an explicit mlir::TypeID for the pass class.
// NOTE(review): presumably needed so the pass's TypeID resolves consistently
// when this code is loaded as a shared-library plugin; confirm against the
// MLIR TypeID documentation.
MLIR_DECLARE_EXPLICIT_TYPE_ID(AlexseevMulAddMergePass)
MLIR_DEFINE_EXPLICIT_TYPE_ID(AlexseevMulAddMergePass)

/// Plugin entry point: returns the plugin descriptor (API version, plugin
/// name, LLVM version string) together with the callback that registers
/// AlexseevMulAddMergePass.
extern "C" LLVM_ATTRIBUTE_WEAK mlir::PassPluginLibraryInfo
mlirGetPassPluginInfo() {
  // Captureless lambda, so it converts to the plain callback stored in the
  // descriptor.
  auto registerPasses = []() {
    mlir::PassRegistration<AlexseevMulAddMergePass>();
  };
  return {MLIR_PLUGIN_API_VERSION, "alexseev_mul_add_merge",
          LLVM_VERSION_STRING, registerPasses};
}
12 changes: 12 additions & 0 deletions mlir/lib/Transforms/lab4/alexseev_danila/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Build script for the AlexseevMulAddMergePass plugin.

# Name of the plugin target.
set(PluginName AlexseevMulAddMergePass)

# Collect every .cpp/.h file under this directory as the plugin's sources.
file(GLOB_RECURSE ALL_SOURCE_FILES *.cpp *.h)
# Build the sources as a pass plugin; the DEPENDS entries are targets that
# must be built before the plugin.
# NOTE(review): BUILDTREE_ONLY presumably keeps the plugin out of the install
# rules - confirm against the add_llvm_pass_plugin documentation.
add_llvm_pass_plugin(${PluginName}
${ALL_SOURCE_FILES}
DEPENDS
intrinsics_gen
MLIRBuiltinLocationAttributesIncGen
BUILDTREE_ONLY
)

# Prepend the plugin to the MLIR test dependency list in the parent scope.
set(MLIR_TEST_DEPENDS ${PluginName} ${MLIR_TEST_DEPENDS} PARENT_SCOPE)
40 changes: 40 additions & 0 deletions mlir/test/Transforms/lab4/alexseev_danila/test.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// RUN: mlir-opt -load-pass-plugin=%mlir_lib_dir/AlexseevMulAddMergePass%shlibext --pass-pipeline="builtin.module(alexseev_mul_add_merge)" %s | FileCheck %s

module {
  // Product on the left-hand side of the addition.
  // double c = a * 6.0 + b;
  // CHECK-LABEL: @foo1
  llvm.func local_unnamed_addr @foo1(%arg0: f64 {llvm.noundef}, %arg1: f64 {llvm.noundef}) attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", "nounwind", "willreturn", ["uwtable", "2"], ["min-legal-vector-width", "0"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "x86-64"], ["target-features", "+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87"], ["tune-cpu", "generic"]]} {
    %0 = llvm.mlir.constant(6.000000e+00 : f64) : f64
    // CHECK-NOT: %1 = llvm.fmul %arg0, %0 : f64
    // CHECK-NOT: %2 = llvm.fadd %1, %arg1 : f64
    // CHECK: %1 = llvm.intr.fma(%arg0, %0, %arg1) : (f64, f64, f64) -> f64
    %1 = llvm.fmul %arg0, %0 : f64
    %2 = llvm.fadd %1, %arg1 : f64
    llvm.return
  }

  // Product on the right-hand side of the addition, so the second branch of
  // the pass is exercised as well.
  // double c = a + 10.0 * b;
  // CHECK-LABEL: @foo2
  llvm.func local_unnamed_addr @foo2(%arg0: f64 {llvm.noundef}, %arg1: f64 {llvm.noundef}) attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", "nounwind", "willreturn", ["uwtable", "2"], ["min-legal-vector-width", "0"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "x86-64"], ["target-features", "+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87"], ["tune-cpu", "generic"]]} {
    %0 = llvm.mlir.constant(10.000000e+00 : f64) : f64
    // CHECK-NOT: %1 = llvm.fmul %arg1, %0 : f64
    // CHECK-NOT: %2 = llvm.fadd %arg0, %1 : f64
    // CHECK: %1 = llvm.intr.fma(%arg1, %0, %arg0) : (f64, f64, f64) -> f64
    %1 = llvm.fmul %arg1, %0 : f64
    %2 = llvm.fadd %arg0, %1 : f64
    llvm.return
  }

  // One product shared by two additions: both additions are fused and the
  // multiplication disappears once its last user is gone.
  // double c = a * b;
  // double d = c + e;
  // double f = c + x;
  // CHECK-LABEL: @foo3
  llvm.func local_unnamed_addr @foo3(%arg0: f64 {llvm.noundef}, %arg1: f64 {llvm.noundef}, %arg2: f64 {llvm.noundef}, %arg3: f64 {llvm.noundef}) attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", "nounwind", "willreturn", ["uwtable", "2"], ["min-legal-vector-width", "0"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "x86-64"], ["target-features", "+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87"], ["tune-cpu", "generic"]]} {
    // CHECK-NOT: %0 = llvm.fmul %arg0, %arg1 : f64
    // CHECK-NOT: %1 = llvm.fadd %0, %arg2 : f64
    // CHECK: %0 = llvm.intr.fma(%arg0, %arg1, %arg2) : (f64, f64, f64) -> f64
    %0 = llvm.fmul %arg0, %arg1 : f64
    %1 = llvm.fadd %0, %arg2 : f64
    // CHECK-NOT: %2 = llvm.fadd %0, %arg3 : f64
    // CHECK: %1 = llvm.intr.fma(%arg0, %arg1, %arg3) : (f64, f64, f64) -> f64
    %2 = llvm.fadd %0, %arg3 : f64
    llvm.return
  }
}

0 comments on commit e8113fa

Please sign in to comment.