This repository has been archived by the owner on Jan 19, 2025. It is now read-only.
forked from NN-complr-tech/llvm
-
Notifications
You must be signed in to change notification settings - Fork 59
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Alexseev Danila. Lab 4. Opt 2. (#162)
* Implemented an FMA pass which merges add and mul
* fix
* fix test
* lil fix
* Update mlir/test/Transforms/lab4/alexseev_danila/test.mlir
---------
Co-authored-by: Alexseev Danila <user@DESKTOP-9FK8SDU.localdomain>
- Loading branch information
Showing
3 changed files
with
102 additions
and
0 deletions.
There are no files selected for viewing
50 changes: 50 additions & 0 deletions
50
mlir/lib/Transforms/lab4/alexseev_danila/AlexseevMulAddMergePass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Tools/Plugins/PassPlugin.h" | ||
|
||
class AlexseevMulAddMergePass | ||
: public mlir::PassWrapper<AlexseevMulAddMergePass, | ||
mlir::OperationPass<mlir::ModuleOp>> { | ||
void createFMAOperation(mlir::LLVM::FAddOp &addOp, mlir::LLVM::FMulOp &mulOp, | ||
mlir::Value &otherOperand) { | ||
mlir::OpBuilder builder(addOp); | ||
mlir::Value fma = builder.create<mlir::LLVM::FMAOp>( | ||
addOp.getLoc(), addOp.getType(), mulOp.getOperand(0), | ||
mulOp.getOperand(1), otherOperand); | ||
addOp.replaceAllUsesWith(fma); | ||
addOp.erase(); | ||
if (mulOp.use_empty()) | ||
mulOp.erase(); | ||
} | ||
|
||
public: | ||
mlir::StringRef getArgument() const final { return "alexseev_mul_add_merge"; } | ||
|
||
mlir::StringRef getDescription() const final { | ||
return "Merge multiplication and addition into a single math.fma"; | ||
} | ||
|
||
void runOnOperation() override { | ||
getOperation().walk([&](mlir::LLVM::FAddOp addOp) { | ||
mlir::Value addLHS = addOp.getOperand(0); | ||
mlir::Value addRHS = addOp.getOperand(1); | ||
|
||
if (auto mulOpLHS = addLHS.getDefiningOp<mlir::LLVM::FMulOp>()) { | ||
createFMAOperation(addOp, mulOpLHS, addRHS); | ||
} else if (auto mulOpRHS = addRHS.getDefiningOp<mlir::LLVM::FMulOp>()) { | ||
createFMAOperation(addOp, mulOpRHS, addLHS); | ||
} | ||
}); | ||
} | ||
}; | ||
|
||
// Give the pass an explicit TypeID (declaration followed by its definition).
MLIR_DECLARE_EXPLICIT_TYPE_ID(AlexseevMulAddMergePass)
MLIR_DEFINE_EXPLICIT_TYPE_ID(AlexseevMulAddMergePass)

/// Plugin entry point: returns the plugin descriptor (API version, plugin
/// name, LLVM version string) together with the callback that registers
/// AlexseevMulAddMergePass.
extern "C" LLVM_ATTRIBUTE_WEAK mlir::PassPluginLibraryInfo
mlirGetPassPluginInfo() {
  // Capture-less lambda, so it converts to the plain callback the descriptor
  // stores.
  const auto registerPasses = []() {
    mlir::PassRegistration<AlexseevMulAddMergePass>();
  };
  return {MLIR_PLUGIN_API_VERSION, "alexseev_mul_add_merge",
          LLVM_VERSION_STRING, registerPasses};
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Name of the plugin target.
set(PluginName AlexseevMulAddMergePass)

# Collect every source and header below this directory. CONFIGURE_DEPENDS makes
# the build system re-check the glob, so files added later are picked up
# without a manual CMake re-configure.
file(GLOB_RECURSE ALL_SOURCE_FILES CONFIGURE_DEPENDS *.cpp *.h)
add_llvm_pass_plugin(${PluginName}
  ${ALL_SOURCE_FILES}
  DEPENDS
  intrinsics_gen
  MLIRBuiltinLocationAttributesIncGen
  BUILDTREE_ONLY
)

# Prepend the plugin to the parent scope's MLIR_TEST_DEPENDS list (presumably so
# the tests that load the plugin have it built first — verify against the
# parent CMakeLists).
set(MLIR_TEST_DEPENDS ${PluginName} ${MLIR_TEST_DEPENDS} PARENT_SCOPE)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
// RUN: mlir-opt -load-pass-plugin=%mlir_lib_dir/AlexseevMulAddMergePass%shlibext --pass-pipeline="builtin.module(alexseev_mul_add_merge)" %s | FileCheck %s | ||
|
||
module {
  // Multiplication is the LEFT operand of the addition.
  // double c = a * 6.0 + b;
  // CHECK-LABEL: @foo1
  llvm.func local_unnamed_addr @foo1(%arg0: f64 {llvm.noundef}, %arg1: f64 {llvm.noundef}) attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", "nounwind", "willreturn", ["uwtable", "2"], ["min-legal-vector-width", "0"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "x86-64"], ["target-features", "+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87"], ["tune-cpu", "generic"]]} {
    %0 = llvm.mlir.constant(6.000000e+00 : f64) : f64
    // CHECK-NOT: %1 = llvm.fmul %arg0, %0 : f64
    // CHECK-NOT: %2 = llvm.fadd %1, %arg1 : f64
    // CHECK: %1 = llvm.intr.fma(%arg0, %0, %arg1) : (f64, f64, f64) -> f64
    %1 = llvm.fmul %arg0, %0 : f64
    %2 = llvm.fadd %1, %arg1 : f64
    llvm.return
  }

  // Multiplication is the RIGHT operand of the addition, which exercises the
  // pass's second (right-operand) branch.
  // double c = a + 10.0 * b;
  // CHECK-LABEL: @foo2
  llvm.func local_unnamed_addr @foo2(%arg0: f64 {llvm.noundef}, %arg1: f64 {llvm.noundef}) attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", "nounwind", "willreturn", ["uwtable", "2"], ["min-legal-vector-width", "0"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "x86-64"], ["target-features", "+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87"], ["tune-cpu", "generic"]]} {
    %0 = llvm.mlir.constant(10.000000e+00 : f64) : f64
    // CHECK-NOT: %1 = llvm.fmul %arg1, %0 : f64
    // CHECK-NOT: %2 = llvm.fadd %arg0, %1 : f64
    // CHECK: %1 = llvm.intr.fma(%arg1, %0, %arg0) : (f64, f64, f64) -> f64
    %1 = llvm.fmul %arg1, %0 : f64
    %2 = llvm.fadd %arg0, %1 : f64
    llvm.return
  }

  // One product shared by two additions: both additions must be fused and the
  // multiplication removed once it has no users left.
  // double c = a * b;
  // double d = c + e;
  // double f = c + x;
  // CHECK-LABEL: @foo3
  llvm.func local_unnamed_addr @foo3(%arg0: f64 {llvm.noundef}, %arg1: f64 {llvm.noundef}, %arg2: f64 {llvm.noundef}, %arg3: f64 {llvm.noundef}) attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", "nounwind", "willreturn", ["uwtable", "2"], ["min-legal-vector-width", "0"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "x86-64"], ["target-features", "+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87"], ["tune-cpu", "generic"]]} {
    // CHECK-NOT: %0 = llvm.fmul %arg0, %arg1 : f64
    // CHECK-NOT: %1 = llvm.fadd %0, %arg2 : f64
    // CHECK: %0 = llvm.intr.fma(%arg0, %arg1, %arg2) : (f64, f64, f64) -> f64
    %0 = llvm.fmul %arg0, %arg1 : f64
    %1 = llvm.fadd %0, %arg2 : f64
    // CHECK-NOT: %2 = llvm.fadd %0, %arg3 : f64
    // CHECK: %1 = llvm.intr.fma(%arg0, %arg1, %arg3) : (f64, f64, f64) -> f64
    %2 = llvm.fadd %0, %arg3 : f64
    llvm.return
  }
}