forked from llvm/llvm-project
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir][mlprogram] Add
mlprogram-pipeline-globals
optimization pass
The added pass optimizes MLProgram global operations by reducing them to only the minimal load/store operations for global tensors. This avoids unnecessary global operations throughout a program and potentially improves operation fusion. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D159228
- Loading branch information
Showing
10 changed files
with
582 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
# Descend into the IR and Transforms subdirectories of this directory. | ||
add_subdirectory(IR) | ||
add_subdirectory(Transforms) |
6 changes: 6 additions & 0 deletions
6
mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Run tablegen on Passes.td to produce the pass declarations in Passes.h.inc. | ||
set(LLVM_TARGET_DEFINITIONS Passes.td) | ||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name MLProgram) | ||
add_public_tablegen_target(MLIRMLProgramPassIncGen) | ||
# Make `mlir-headers` depend on the generated pass header target. | ||
add_dependencies(mlir-headers MLIRMLProgramPassIncGen) | ||
|
||
# Emit the pass documentation (-gen-pass-doc) from the Passes definitions. | ||
add_mlir_doc(Passes MLProgramPasses ./ -gen-pass-doc) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES_H_ | ||
#define MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES_H_ | ||
|
||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/Pass/Pass.h" | ||
|
||
namespace mlir { | ||
namespace ml_program { | ||
|
||
#define GEN_PASS_DECL | ||
#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc" | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Registration | ||
//===----------------------------------------------------------------------===// | ||
|
||
/// Creates an instance of the `mlprogram-pipeline-globals` pass, which | ||
/// operates on a `ModuleOp`. Ownership of the pass is given to the caller. | ||
std::unique_ptr<OperationPass<ModuleOp>> createMLProgramPipelineGlobalsPass(); | ||
|
||
/// Generate the code for registering passes. | ||
#define GEN_PASS_REGISTRATION | ||
#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc" | ||
|
||
} // namespace ml_program | ||
} // namespace mlir | ||
|
||
#endif // MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
//===-- Passes.td - pass definition file -------------------*- tablegen -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES | ||
#define MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
// Declares the `mlprogram-pipeline-globals` pass. It is anchored on | ||
// `ModuleOp` and is instantiated through the C++ factory named in | ||
// `constructor` below. | ||
def MLProgramPipelineGlobals : Pass<"mlprogram-pipeline-globals", "ModuleOp"> { | ||
let summary = "Optimize `ml_program` global operations for read and store"; | ||
let description = [{ | ||
`ml_program`'s load and store operations can be optimized for | ||
write-write or write-read sets of operations. This allows known | ||
tensors to not be re-read when the value is already known in IR. | ||
|
||
The pass is designed to handle both nested regions and function calls | ||
safely. | ||
}]; | ||
let constructor = "mlir::ml_program::createMLProgramPipelineGlobalsPass()"; | ||
} | ||
|
||
#endif // MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
# Descend into the IR and Transforms subdirectories of this directory. | ||
add_subdirectory(IR) | ||
add_subdirectory(Transforms) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Builds the MLIRMLProgramTransforms library from PipelineGlobalOps.cpp. | ||
add_mlir_dialect_library(MLIRMLProgramTransforms | ||
PipelineGlobalOps.cpp | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MLProgram/Transforms | ||
|
||
# The tablegen target that generates Passes.h.inc must be built first. | ||
DEPENDS | ||
MLIRMLProgramPassIncGen | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRIR | ||
MLIRMLProgramDialect | ||
MLIRPass | ||
) |
234 changes: 234 additions & 0 deletions
234
mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,234 @@ | ||
//===- PipelineGlobalOps.cpp - Pipeline Global Ops Pass -------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/MLProgram/Transforms/Passes.h" | ||
|
||
#include "mlir/Dialect/MLProgram/IR/MLProgram.h" | ||
#include "mlir/Dialect/MLProgram/Transforms/Passes.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
namespace mlir { | ||
namespace ml_program { | ||
#define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS | ||
#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc" | ||
|
||
namespace { | ||
|
||
class MLProgramPipelineGlobals | ||
: public impl::MLProgramPipelineGlobalsBase<MLProgramPipelineGlobals> { | ||
public: | ||
void runOnOperation() override; | ||
|
||
private: | ||
LogicalResult buildGlobalMap(ModuleOp op); | ||
|
||
void ProcessBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad, | ||
llvm::DenseSet<SymbolRefAttr> &symbolStore); | ||
|
||
llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> loadSymbolsMap; | ||
llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> storeSymbolsMap; | ||
}; | ||
|
||
// Traverses upwards searching for the operation mapped by the symbol: the
// lookup is attempted from `baseOp` and then from each of its ancestors in
// turn. Returns nullptr when no lookup succeeds.
static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
  Operation *scope = baseOp;
  while (scope) {
    if (Operation *found = SymbolTable::lookupNearestSymbolFrom(scope, symbol))
      return found;
    scope = scope->getParentOp();
  }
  return nullptr;
}
|
||
// Builds, for every symbol that is the callee of some call in `module`, the
// set of MLProgram globals loaded / stored by that callee or by anything it
// transitively calls. Results are written to `loadSymbolsMap` and
// `storeSymbolsMap`.
//
// Returns failure when a call cannot be analyzed: either its callee is an SSA
// value, or its callee symbol does not resolve to an operation.
LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
  // Drop anything left over from a previous run of this pass instance.
  loadSymbolsMap.clear();
  storeSymbolsMap.clear();

  llvm::DenseMap<SymbolRefAttr, Operation *> callableMap;
  auto res = module->walk([&](Operation *op) {
    if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
      auto callable = caller.getCallableForCallee();
      // For now we do not know how to handle Value based tracing, so fail.
      if (mlir::isa<Value>(callable))
        return WalkResult::interrupt();

      auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
      Operation *func = getFromSymbol(op, symbol);
      // A callee that cannot be resolved has no body to inspect; give up
      // instead of recording a null operation that would be dereferenced
      // below.
      if (!func)
        return WalkResult::interrupt();
      callableMap[symbol] = func;
    }
    return WalkResult::advance();
  });

  if (res.wasInterrupted())
    return failure();

  // First grab all symbols loaded or stored directly inside each callee. This
  // does not follow calls yet.
  llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opLoadSymbols;
  llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opStoreSymbols;
  for (const auto &entry : callableMap) {
    llvm::DenseSet<SymbolRefAttr> loadSymbols;
    llvm::DenseSet<SymbolRefAttr> storeSymbols;

    entry.getSecond()->walk([&](Operation *nested) {
      if (auto load = mlir::dyn_cast<GlobalLoadOp>(nested))
        loadSymbols.insert(load.getGlobal());
      else if (auto store = mlir::dyn_cast<GlobalStoreOp>(nested))
        storeSymbols.insert(store.getGlobal());
    });

    opLoadSymbols[entry.getFirst()] = std::move(loadSymbols);
    opStoreSymbols[entry.getFirst()] = std::move(storeSymbols);
  }

  // For each callee, accumulate the globals touched by it and by everything
  // reachable from it through calls. The `visited` set guards against
  // recursive call chains.
  for (const auto &entry : callableMap) {
    SymbolRefAttr thisSymbol = entry.getFirst();
    llvm::SmallVector<SymbolRefAttr> work = {thisSymbol};
    llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol};
    llvm::DenseSet<SymbolRefAttr> loadSymbols;
    llvm::DenseSet<SymbolRefAttr> storeSymbols;

    // `work` grows while it is scanned, so index it rather than iterate.
    for (size_t i = 0; i < work.size(); ++i) {
      // Copy the element: the push_back below may reallocate `work`.
      SymbolRefAttr current = work[i];

      // lookup() never inserts, so `callableMap` is not modified while the
      // enclosing loop iterates over it.
      if (Operation *func = callableMap.lookup(current)) {
        func->walk([&](CallOpInterface call) {
          auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
          if (visited.insert(symbol).second)
            work.push_back(symbol);
        });
      }

      auto loadIt = opLoadSymbols.find(current);
      if (loadIt != opLoadSymbols.end())
        for (SymbolRefAttr load : loadIt->second)
          loadSymbols.insert(load);

      auto storeIt = opStoreSymbols.find(current);
      if (storeIt != opStoreSymbols.end())
        for (SymbolRefAttr store : storeIt->second)
          storeSymbols.insert(store);
    }

    loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
    storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
  }

  return success();
}
|
||
// Processes each operation of `block` in order, deleting unneeded loads /
// stores, recursing into nested regions and accounting for function calls.
//
// Two pieces of per-global state are tracked while scanning the block:
//   * `previousLoads`  - the value the global is currently known to hold;
//   * `previousStores` - the latest store that nothing has observed yet.
// A load of a global with a known value is replaced by that value, and a new
// store makes a still-unobserved earlier store to the same global dead. Calls
// and ops with regions invalidate this state for the globals they may touch.
//
// `symbolLoad` / `symbolStore` are outputs: every global that may be loaded /
// stored while executing `block` (directly, inside a nested region, or through
// a callee) is added, so that the caller can invalidate its own state.
void MLProgramPipelineGlobals::ProcessBlock(
    Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
    llvm::DenseSet<SymbolRefAttr> &symbolStore) {

  llvm::DenseMap<SymbolRefAttr, Value> previousLoads;
  llvm::DenseMap<SymbolRefAttr, Operation *> previousStores;
  // Erasure is deferred so the block is not modified while it is iterated.
  llvm::SmallVector<Operation *> toDelete;
  for (auto &op : block) {
    // If this is a global load, remap to a previous value if known and delete
    // this load. Otherwise remember its result as the currently known value.
    if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
      auto ref = load.getGlobal();
      symbolLoad.insert(ref);
      auto known = previousLoads.find(ref);
      if (known != previousLoads.end()) {
        toDelete.push_back(&op);
        load.getResult().replaceAllUsesWith(known->second);
      } else {
        previousLoads[ref] = load.getResult();
      }
      continue;
    }

    // Delete a previous store if it exists and is not needed, and update the
    // most recent known value for this global ref.
    if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
      auto ref = store.getGlobal();
      symbolStore.insert(ref);
      auto prior = previousStores.find(ref);
      if (prior != previousStores.end())
        toDelete.push_back(prior->second);

      previousLoads[ref] = store.getValue();
      previousStores[ref] = &op;
      continue;
    }

    // If a function is called, clear known values for loads/stores used by
    // the function or its sub-functions. The touched globals are also
    // reported to the caller: when this block is nested inside another op,
    // the enclosing block must invalidate its state for them as well.
    if (auto call = mlir::dyn_cast<CallOpInterface>(op)) {
      auto callee = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());

      // find() avoids copying the sets and inserting empty map entries.
      auto loaded = loadSymbolsMap.find(callee);
      if (loaded != loadSymbolsMap.end()) {
        for (SymbolRefAttr sym : loaded->second) {
          symbolLoad.insert(sym);
          // The callee observes the pending store, so it must be kept.
          previousStores.erase(sym);
        }
      }

      auto stored = storeSymbolsMap.find(callee);
      if (stored != storeSymbolsMap.end()) {
        for (SymbolRefAttr sym : stored->second) {
          symbolStore.insert(sym);
          previousLoads.erase(sym);
          previousStores.erase(sym);
        }
      }
      continue;
    }

    // If the op has sub-regions, recurse inside. We make no guarantees whether
    // the nested code actually executes, so it only invalidates state here.
    llvm::DenseSet<SymbolRefAttr> opSymbolLoad;
    llvm::DenseSet<SymbolRefAttr> opSymbolStore;
    for (auto &region : op.getRegions()) {
      for (auto &nestedBlock : region) {
        ProcessBlock(nestedBlock, opSymbolLoad, opSymbolStore);
      }
    }

    // Update current state from the sub-blocks.
    for (auto change : opSymbolLoad) {
      symbolLoad.insert(change);
      previousStores.erase(change);
    }

    for (auto change : opSymbolStore) {
      symbolStore.insert(change);
      previousLoads.erase(change);
      previousStores.erase(change);
    }
  }

  for (auto *op : toDelete) {
    op->erase();
  }
}
|
||
void MLProgramPipelineGlobals::runOnOperation() { | ||
auto targetOp = getOperation(); | ||
if (failed(buildGlobalMap(targetOp))) { | ||
return; | ||
} | ||
|
||
for (auto &funcOp : *targetOp.getBody()) { | ||
for (auto ®ion : funcOp.getRegions()) { | ||
for (auto &block : region.getBlocks()) { | ||
llvm::DenseSet<SymbolRefAttr> symbolsLoaded; | ||
llvm::DenseSet<SymbolRefAttr> symbolsStored; | ||
ProcessBlock(block, symbolsLoaded, symbolsStored); | ||
} | ||
} | ||
} | ||
} | ||
|
||
} // namespace | ||
|
||
std::unique_ptr<OperationPass<mlir::ModuleOp>> | ||
createMLProgramPipelineGlobalsPass() { | ||
return std::make_unique<MLProgramPipelineGlobals>(); | ||
} | ||
|
||
} // namespace ml_program | ||
} // namespace mlir |
Oops, something went wrong.