Skip to content
This repository has been archived by the owner on Jan 19, 2025. It is now read-only.

Savotina Valeria, Lab №4, var: 2 #161

Merged
merged 4 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions mlir/lib/Transforms/lab4/savorina_valeria/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Build script for the SavotinaMulAddPass plugin.

# Name of the plugin target; reused below for the target and test dependency.
set(PluginName SavotinaMulAddPass)

# Collect every .cpp and .h file found recursively under this directory.
file(GLOB_RECURSE ALL_SOURCE_FILES *.cpp *.h)
# Register the collected sources as a pass plugin target. The targets listed
# after DEPENDS are declared as dependencies of the plugin.
# NOTE(review): BUILDTREE_ONLY presumably keeps the plugin out of the install
# tree -- confirm against the add_llvm_pass_plugin definition.
add_llvm_pass_plugin(${PluginName}
${ALL_SOURCE_FILES}
DEPENDS
intrinsics_gen
MLIRBuiltinLocationAttributesIncGen
BUILDTREE_ONLY
)

# Prepend this plugin to MLIR_TEST_DEPENDS and publish the updated list to the
# parent scope.
set(MLIR_TEST_DEPENDS ${PluginName} ${MLIR_TEST_DEPENDS} PARENT_SCOPE)
82 changes: 82 additions & 0 deletions mlir/lib/Transforms/lab4/savorina_valeria/SavotinaMulAddPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Tools/Plugins/PassPlugin.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

namespace {
/// Module pass that fuses an `llvm.fmul` feeding an `llvm.fadd` into a single
/// `math.fma` operation.
///
/// Every `llvm.fadd` whose two operands have a scalar floating-point type is
/// inspected. If one operand is the result of an `llvm.fmul` and that result
/// has exactly one use (this add), the pair is replaced by
/// `math.fma(mul.lhs, mul.rhs, otherAddOperand)` inserted at the position of
/// the add, and both original operations are erased. The left operand is
/// tried first, then the right one.
///
/// NOTE(review): the rewrite is applied regardless of any fast-math flags on
/// the original operations; a fused multiply-add may round differently from a
/// separate multiply and add -- confirm this is acceptable for the users of
/// this pass.
class SavotinaMulAddPass
    : public PassWrapper<SavotinaMulAddPass, OperationPass<ModuleOp>> {
public:
  /// Name under which the pass is referenced in a pass pipeline.
  StringRef getArgument() const final { return "SavotinaMulAddPass"; }

  /// Short human-readable description of the pass.
  StringRef getDescription() const final { return "fma pass"; }

  /// The pass creates `math.fma`, so the Math dialect must be registered.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mlir::math::MathDialect>();
  }

  /// Walks all `llvm.fadd` operations of the module and fuses each one with a
  /// suitable `llvm.fmul` operand.
  void runOnOperation() override {
    mlir::ModuleOp module = getOperation();
    mlir::OpBuilder builder(module);

    // A typed walk visits only llvm.fadd operations. The callback erases the
    // visited add and an operation that defines one of its operands.
    module.walk([&](mlir::LLVM::FAddOp addOp) {
      mlir::Value addLhs = addOp.getOperand(0);
      mlir::Value addRhs = addOp.getOperand(1);

      // Only scalar floating-point additions are rewritten.
      if (!addLhs.getType().isa<mlir::FloatType>() ||
          !addRhs.getType().isa<mlir::FloatType>()) {
        return;
      }

      // Prefer the left operand as the product; fall back to the right one
      // when the left operand cannot be fused.
      if (fuseIfProduct(builder, addOp, addLhs, addRhs)) {
        return;
      }
      fuseIfProduct(builder, addOp, addRhs, addLhs);
    });
  }

private:
  /// Tries to fuse `addOp` with the multiplication that defines `candidate`.
  ///
  /// \param builder   Builder used to create the replacement operation.
  /// \param addOp     The addition being rewritten.
  /// \param candidate Operand of `addOp` that may be an `llvm.fmul` result.
  /// \param addend    The other operand of `addOp`; becomes the fma addend.
  /// \return true if `addOp` and the multiplication were replaced and erased,
  ///         false if nothing was changed.
  static bool fuseIfProduct(mlir::OpBuilder &builder, mlir::LLVM::FAddOp addOp,
                            mlir::Value candidate, mlir::Value addend) {
    auto mulOp = candidate.getDefiningOp<mlir::LLVM::FMulOp>();
    if (!mulOp) {
      return false;
    }

    // The product must have exactly one use. `candidate` is an operand of
    // `addOp`, so that single use is this add. Requiring one use (rather than
    // "all uses belong to addOp") also rejects `fadd %m, %m`, where the addend
    // is the product itself and the multiplication must stay alive.
    if (!mulOp->getResult(0).hasOneUse()) {
      return false;
    }

    builder.setInsertionPoint(addOp);
    mlir::Value fmaResult = builder.create<mlir::math::FmaOp>(
        addOp.getLoc(), mulOp.getOperand(0), mulOp.getOperand(1), addend);
    addOp.replaceAllUsesWith(fmaResult);
    // Erase the add first: it holds the only remaining use of the product.
    addOp.erase();
    mulOp.erase();
    return true;
  }
};
} // namespace

// Declare and define an explicit MLIR type ID for SavotinaMulAddPass. Both
// macros are invoked at global scope, outside the anonymous namespace that
// holds the class.
MLIR_DECLARE_EXPLICIT_TYPE_ID(SavotinaMulAddPass)
MLIR_DEFINE_EXPLICIT_TYPE_ID(SavotinaMulAddPass)

/// Builds the plugin descriptor for this library: API version, plugin name
/// "SavotinaMulAddPass", plugin version "0.1" and a callback that registers
/// SavotinaMulAddPass.
mlir::PassPluginLibraryInfo getFunctionCallCounterPassPluginInfo() {
  // Captureless lambda, so it converts to the plain callback stored in the
  // descriptor.
  const auto registerPasses = []() {
    mlir::PassRegistration<SavotinaMulAddPass>();
  };
  return {MLIR_PLUGIN_API_VERSION, "SavotinaMulAddPass", "0.1",
          registerPasses};
}

/// Weak, C-linkage entry point that hands out this plugin's descriptor.
/// NOTE(review): presumably the symbol resolved when the library is loaded
/// with `-load-pass-plugin` -- confirm against the MLIR plugin loader.
extern "C" LLVM_ATTRIBUTE_WEAK mlir::PassPluginLibraryInfo
mlirGetPassPluginInfo() {
  mlir::PassPluginLibraryInfo pluginInfo =
      getFunctionCallCounterPassPluginInfo();
  return pluginInfo;
}
93 changes: 93 additions & 0 deletions mlir/test/Transforms/lab4/savotina_valeria/test.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// RUN: mlir-opt -load-pass-plugin=%mlir_lib_dir/SavotinaMulAddPass%shlibext --pass-pipeline="builtin.module(SavotinaMulAddPass)" %s | FileCheck %s

module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi32>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi32>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi32>>, #dlti.dl_entry<f128, dense<128> : vector<2xi32>>, #dlti.dl_entry<i64, dense<64> : vector<2xi32>>, #dlti.dl_entry<f80, dense<128> : vector<2xi32>>, #dlti.dl_entry<i8, dense<8> : vector<2xi32>>, #dlti.dl_entry<i1, dense<8> : vector<2xi32>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi32>>, #dlti.dl_entry<f64, dense<64> : vector<2xi32>>, #dlti.dl_entry<f16, dense<16> : vector<2xi32>>, #dlti.dl_entry<i16, dense<16> : vector<2xi32>>, #dlti.dl_entry<i32, dense<32> : vector<2xi32>>, #dlti.dl_entry<"dlti.stack_alignment", 128 : i32>, #dlti.dl_entry<"dlti.endianness", "little">>} {
// Basic case: a product used only as the left operand of one addition. The
// two operations are expected to collapse into a single math.fma.
// CHECK-LABEL: @_Z8funcZeroddd
llvm.func local_unnamed_addr @_Z8funcZeroddd(%arg0: f64 {llvm.noundef}, %arg1: f64 {llvm.noundef}, %arg2: f64 {llvm.noundef}) -> (f64 {llvm.noundef}) attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", "nounwind", "willreturn", ["uwtable", "2"], ["min-legal-vector-width", "0"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "x86-64"], ["target-features", "+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87"], ["tune-cpu", "generic"]]} {
%0 = llvm.fmul %arg0, %arg2 : f64
%1 = llvm.fadd %0, %arg1 : f64
llvm.return %1 : f64

// CHECK: %0 = math.fma %arg0, %arg2, %arg1 : f64
// CHECK-NEXT: llvm.return %0 : f64
}

// Chained case: the first product/addition pair is fused, and its result then
// serves as the addend when the second product (right operand of the final
// addition) is fused as well.
// CHECK-LABEL: @_Z7funcOnedddd
llvm.func local_unnamed_addr @_Z7funcOnedddd(%arg0: f64 {llvm.noundef}, %arg1: f64 {llvm.noundef}, %arg2: f64 {llvm.noundef}, %arg3: f64 {llvm.noundef}) -> (f64 {llvm.noundef}) attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", "nounwind", "willreturn", ["uwtable", "2"], ["min-legal-vector-width", "0"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "x86-64"], ["target-features", "+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87"], ["tune-cpu", "generic"]]} {
%0 = llvm.fmul %arg0, %arg1 : f64
%1 = llvm.fadd %0, %arg2 : f64
%2 = llvm.fmul %arg2, %arg3 : f64
%3 = llvm.fadd %1, %2 : f64
llvm.return %3 : f64

// CHECK: %0 = math.fma %arg0, %arg1, %arg2 : f64
// CHECK-NEXT: %1 = math.fma %arg2, %arg3, %0 : f64
// CHECK-NEXT: llvm.return %1 : f64
}

// The product is defined first but consumed by an addition two operations
// later. The expected output places the math.fma at the position of that
// addition and leaves the unrelated fadd/fsub and the final fadd untouched.
// CHECK-LABEL: @_Z7funcTwoddd
llvm.func local_unnamed_addr @_Z7funcTwoddd(%arg0: f64 {llvm.noundef}, %arg1: f64 {llvm.noundef}, %arg2: f64 {llvm.noundef}) -> (f64 {llvm.noundef}) attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", "nounwind", "willreturn", ["uwtable", "2"], ["min-legal-vector-width", "0"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "x86-64"], ["target-features", "+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87"], ["tune-cpu", "generic"]]} {
%0 = llvm.fmul %arg1, %arg2 : f64
%1 = llvm.fadd %arg0, %arg1 : f64
%2 = llvm.fsub %1, %arg2 : f64
%3 = llvm.fadd %0, %arg0 : f64
%4 = llvm.fadd %2, %3 : f64
llvm.return %4 : f64

// CHECK: %0 = llvm.fadd %arg0, %arg1 : f64
// CHECK-NEXT: %1 = llvm.fsub %0, %arg2 : f64
// CHECK-NEXT: %2 = math.fma %arg1, %arg2, %arg0 : f64
// CHECK-NEXT: %3 = llvm.fadd %1, %2 : f64
// CHECK-NEXT: llvm.return %3 : f64
}

// The addend is a constant. Only the first addition has a product operand and
// is fused; the second addition is expected to remain an llvm.fadd.
// CHECK-LABEL: @_Z9funcThreeddd
llvm.func local_unnamed_addr @_Z9funcThreeddd(%arg0: f64 {llvm.noundef}, %arg1: f64 {llvm.noundef}, %arg2: f64 {llvm.noundef}) -> (f64 {llvm.noundef}) attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", "nounwind", "willreturn", ["uwtable", "2"], ["min-legal-vector-width", "0"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "x86-64"], ["target-features", "+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87"], ["tune-cpu", "generic"]]} {
%0 = llvm.mlir.constant(-2.000000e+00 : f64) : f64
%1 = llvm.fmul %arg0, %arg2 : f64
%2 = llvm.fadd %1, %0 : f64
%3 = llvm.fadd %2, %arg1 : f64
llvm.return %3 : f64

// CHECK: %0 = llvm.mlir.constant(-2.000000e+00 : f64) : f64
// CHECK-NEXT: %1 = math.fma %arg0, %arg2, %0 : f64
// CHECK-NEXT: %2 = llvm.fadd %1, %arg1 : f64
// CHECK-NEXT: llvm.return %2 : f64
}

// The first product has two users (a multiplication and an addition), so it
// is expected to stay. The second product is used only as the right operand
// of the final addition and is expected to be fused with it.
// CHECK-LABEL: @_Z8funcFourddd
llvm.func local_unnamed_addr @_Z8funcFourddd(%arg0: f64 {llvm.noundef}, %arg1: f64 {llvm.noundef}, %arg2: f64 {llvm.noundef}) -> (f64 {llvm.noundef}) attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", "nounwind", "willreturn", ["uwtable", "2"], ["min-legal-vector-width", "0"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "x86-64"], ["target-features", "+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87"], ["tune-cpu", "generic"]]} {
%0 = llvm.fmul %arg0, %arg2 : f64
%1 = llvm.fmul %0, %arg1 : f64
%2 = llvm.fadd %0, %arg0 : f64
%3 = llvm.fadd %2, %1 : f64
llvm.return %3 : f64

// CHECK: %0 = llvm.fmul %arg0, %arg2 : f64
// CHECK-NEXT: %1 = llvm.fadd %0, %arg0 : f64
// CHECK-NEXT: %2 = math.fma %0, %arg1, %1 : f64
// CHECK-NEXT: llvm.return %2 : f64
}

// Same shape as the basic case, with the product formed from the first two
// arguments and the third argument as addend.
// CHECK-LABEL: @_Z8funcFiveddd
llvm.func local_unnamed_addr @_Z8funcFiveddd(%arg0: f64 {llvm.noundef}, %arg1: f64 {llvm.noundef}, %arg2: f64 {llvm.noundef}) -> (f64 {llvm.noundef}) attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", "nounwind", "willreturn", ["uwtable", "2"], ["min-legal-vector-width", "0"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "x86-64"], ["target-features", "+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87"], ["tune-cpu", "generic"]]} {
%0 = llvm.fmul %arg0, %arg1 : f64
%1 = llvm.fadd %0, %arg2 : f64
llvm.return %1 : f64

// CHECK: %0 = math.fma %arg0, %arg1, %arg2 : f64
// CHECK-NEXT: llvm.return %0 : f64
}

// Negative case: the product feeds two different additions, so it has more
// than one user and the function is expected to come out unchanged.
// CHECK-LABEL: @_Z7funcSixddd
llvm.func local_unnamed_addr @_Z7funcSixddd(%arg0: f64 {llvm.noundef}, %arg1: f64 {llvm.noundef}, %arg2: f64 {llvm.noundef}) -> (f64 {llvm.noundef}) attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", "nounwind", "willreturn", ["uwtable", "2"], ["min-legal-vector-width", "0"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "x86-64"], ["target-features", "+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87"], ["tune-cpu", "generic"]]} {
%0 = llvm.fmul %arg0, %arg1 : f64
%1 = llvm.fadd %0, %arg2 : f64
%2 = llvm.fadd %0, %1 : f64
llvm.return %2 : f64

// CHECK: %0 = llvm.fmul %arg0, %arg1 : f64
// CHECK-NEXT: %1 = llvm.fadd %0, %arg2 : f64
// CHECK-NEXT: %2 = llvm.fadd %0, %1 : f64
// CHECK-NEXT: llvm.return %2 : f64
}
}
Loading