diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp index 82c94c717109..b388b9ceb9b5 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp @@ -114,7 +114,7 @@ EncodingAttr EncodingAttr::clone(AffineMap bcastMap) { AffineMapAttr::get(bcastMap), getRoundDimsTo(), getLayouts()); } -EncodingAttr EncodingAttr::cloneWithLayouts(SmallVector layouts) { +EncodingAttr EncodingAttr::cloneWithLayouts(ArrayRef layouts) { MLIRContext *ctx = getContext(); return get(ctx, getOperandIndex(), getOpType(), getElementTypes(), /*user_indexing_maps=*/ArrayAttr(), diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td index b674626f77e9..434356a7e66c 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td @@ -116,7 +116,7 @@ def EncodingAttr : /// Clones an encoding with a new layout list and drops other optional /// parameters (because they are resolved). 
- EncodingAttr cloneWithLayouts(SmallVector layouts); + EncodingAttr cloneWithLayouts(ArrayRef layouts); }]; let genVerifyDecl = 0; diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp index 581ba93f7fbc..e28d08fb9a89 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp @@ -15,6 +15,7 @@ #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/SourceMgr.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" @@ -121,19 +122,18 @@ class HALAffinityAnalysisDialectInterface : public IREE::Stream::AffinityAnalysisDialectInterface { public: using AffinityAnalysisDialectInterface::AffinityAnalysisDialectInterface; - IREE::Stream::LayoutAttrSolverFn - makeLayoutAttrSolver(ModuleOp moduleOp) const { - return [=](IREE::Stream::AffinityAttr aff, Operation *op, - SetVector &layoutAttrs) { - // TODO: This needs to be in the lambda. Otherwise, it could crash because - the root op (i.e., the original moduleOp) could be outdated. + IREE::Stream::ResolveLayoutAttrFn + makeLayoutAttrResolver(ModuleOp moduleOp) const { + return [=](IREE::Stream::AffinityAttr affinityAttr, Operation *op, + SetVector &layoutAttrs) -> LogicalResult { + // This needs to be in the lambda because the moduleOp could be modified.
IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp); if (failed(deviceAnalysis.run())) { - op->emitError("failed to run DeviceAnalysis"); - return failure(); + return op->emitError("failed to run DeviceAnalysis"); } SetVector resultSet; - deviceAnalysis.gatherRequiredExecutableTargets(aff, op, resultSet); + deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, op, + resultSet); // TODO(hanchung): Populate the EncodingLayoutAttr when it is ready. layoutAttrs.insert(resultSet.begin(), resultSet.end()); return success(); diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.h b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.h index 9bdcd3e232ad..ea7db9172687 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.h +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.h @@ -16,7 +16,7 @@ namespace mlir::iree_compiler::IREE::Stream { -using LayoutAttrSolverFn = std::function &)>; class AffinityAnalysisDialectInterface @@ -24,7 +24,11 @@ class AffinityAnalysisDialectInterface public: AffinityAnalysisDialectInterface(Dialect *dialect) : Base(dialect) {} - virtual LayoutAttrSolverFn makeLayoutAttrSolver(ModuleOp moduleOp) const = 0; + /// The `moduleOp` must remain live and unmodified for as long as the returned + /// capture is. Otherwise, it will likely be incorrect or crash if the module + /// op is mutated, especially when module scope analysis is run. 
+ virtual ResolveLayoutAttrFn + makeLayoutAttrResolver(ModuleOp moduleOp) const = 0; }; } // namespace mlir::iree_compiler::IREE::Stream diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp index b1fa65904272..92431666e4ed 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp @@ -54,16 +54,16 @@ SmallVector gatherUsedDialectInterfaces(mlir::ModuleOp moduleOp) { // TODO(hanchung): Add "cloneWithEncoding" method to RankedTensorType. static RankedTensorType cloneWithEncoding(RankedTensorType type, - Attribute encoding) { + Attribute encodingAttr) { return RankedTensorType::get(type.getShape(), type.getElementType(), - encoding); + encodingAttr); } -static LogicalResult -addLayoutsToTensorPhaseOps(ModuleOp moduleOp, FunctionOpInterface funcOp, - LayoutAttrSolverFn makeLayoutAttrFn) { - SmallVector candidates; - funcOp.walk([&](AffinityOpInterface affinityOp) { +static LogicalResult addLayoutsToTensorPhaseOps( + ModuleOp moduleOp, FunctionOpInterface funcOp, + IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr) { + SmallVector candidates; + funcOp.walk([&](IREE::Stream::AffinityOpInterface affinityOp) { // Only need to update encoding types for ops that have TensorPhaseOp trait. if (!affinityOp->hasTrait()) { return; @@ -73,8 +73,8 @@ addLayoutsToTensorPhaseOps(ModuleOp moduleOp, FunctionOpInterface funcOp, // TODO(hanchung): We should use the default device in this case. However, // it is not guaranteed that default device attribute will always be set in // the IR. (Is the statement correct?) 
- auto affAttr = affinityOp.getAffinityAttr(); - if (!affAttr) { + auto affinityAttr = affinityOp.getAffinityAttr(); + if (!affinityAttr) { return; } candidates.push_back(affinityOp); @@ -86,47 +86,48 @@ addLayoutsToTensorPhaseOps(ModuleOp moduleOp, FunctionOpInterface funcOp, IRRewriter rewriter(funcOp.getContext()); for (auto affinityOp : candidates) { - auto affAttr = affinityOp.getAffinityAttr(); + auto affinityAttr = affinityOp.getAffinityAttr(); SetVector layouts; - if (failed(makeLayoutAttrFn(affAttr, moduleOp, layouts))) { - affinityOp.emitError("failed on making layouts"); - return failure(); + if (failed(resolveLayoutAttr(affinityAttr, moduleOp, layouts))) { + return affinityOp.emitError("failed on making layouts"); } + // Returns an updated encoding attribute if an encoding attribute is present + // in the type. Otherwise, returns std::nullopt. auto getEncodingWithNewLayouts = [=](Type type) -> std::optional { auto rankedTensorType = dyn_cast(type); if (!rankedTensorType) { return std::nullopt; } - auto encoding = IREE::Encoding::getEncodingAttr(rankedTensorType); - if (!encoding) { + auto encodingAttr = IREE::Encoding::getEncodingAttr(rankedTensorType); + if (!encodingAttr) { return std::nullopt; } - SmallVector attrs(layouts.begin(), layouts.end()); - return encoding.cloneWithLayouts(attrs); + return encodingAttr.cloneWithLayouts(layouts.getArrayRef()); }; + // TODO(hanchung): Update other Stream operations. 
LogicalResult result = TypeSwitch(affinityOp) - .Case([&](auto sizeOfOp) { + .Case([&](auto sizeOfOp) { auto encodingType = dyn_cast(sizeOfOp.getEncoding()); if (!encodingType) { return success(); } - std::optional encoding = + std::optional encodingAttr = getEncodingWithNewLayouts(encodingType); - if (!encoding.has_value()) { + if (!encodingAttr.has_value()) { return success(); } rewriter.modifyOpInPlace(sizeOfOp, [&] { sizeOfOp.setEncoding( - cloneWithEncoding(encodingType, encoding.value())); + cloneWithEncoding(encodingType, encodingAttr.value())); }); return success(); }) - .Default([](auto *op) { return success(); }); + .Default([](auto *op) { return failure(); }); if (failed(result)) { return failure(); @@ -140,25 +141,24 @@ struct SpecializeEncodingsPass : public impl::SpecializeEncodingsPassBase { void runOnOperation() override { ModuleOp moduleOp = getOperation(); - auto usedDialects = - gatherUsedDialectInterfaces(moduleOp); + auto usedDialects = gatherUsedDialectInterfaces< + IREE::Stream::AffinityAnalysisDialectInterface>(moduleOp); if (usedDialects.size() != 1) { moduleOp.emitError("expected only one dialect implementing " "AffinityAnalysisDialectInterface"); return signalPassFailure(); } - SymbolTable symbolTable(moduleOp); llvm::MapVector executableOps; for (auto executableOp : moduleOp.getOps()) { executableOps[executableOp.getName()] = executableOp; } - LayoutAttrSolverFn makeLayoutAttrFn = - usedDialects[0]->makeLayoutAttrSolver(moduleOp); - for (auto funcOp : moduleOp.getOps()) { - if (failed( - addLayoutsToTensorPhaseOps(moduleOp, funcOp, makeLayoutAttrFn))) { + IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr = + usedDialects[0]->makeLayoutAttrResolver(moduleOp); + for (auto funcOp : moduleOp.getOps()) { + if (failed(addLayoutsToTensorPhaseOps(moduleOp, funcOp, + resolveLayoutAttr))) { funcOp.emitError( "failed on adding layouts to Stream::TensorPhaseOp with encodings"); return signalPassFailure();