Skip to content

Commit

Permalink
Address comments.
Browse files Browse the repository at this point in the history
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
  • Loading branch information
MaheshRavishankar committed Feb 9, 2025
1 parent 19fca49 commit 0f9bd38
Showing 1 changed file with 26 additions and 30 deletions.
56 changes: 26 additions & 30 deletions compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,41 +333,37 @@ static bool hasCompatibleOuterParallelLoops(
consumerIndexingMap, rootOuterParallelLoops);
}

/// For all uses of an operation, finds the use that dominates all other uses.
/// For all uses of an operation, return the uses that could be fused.
/// The returned vector contains the uses in dominance order.
static SmallVector<OpOperand *>
getFusableUse(Operation *op, DominanceInfo const &dominanceInfo,
bool aggressiveFusion) {
getFusableUses(MLIRContext *context, Operation *op,
DominanceInfo const &dominanceInfo, bool aggressiveFusion) {
if (!aggressiveFusion && llvm::count_if(op->getUses(), [](OpOperand &use) {
return !isa<tensor::DimOp>(use.getOwner());
}) != 1) {
return {};
}

// Collect non-dim users.
SetVector<Operation *> ignoredUsers;
for (Operation *user : op->getUsers()) {
if (isa<tensor::DimOp>(user)) {
ignoredUsers.insert(user);
}
}

// Find the use in a non-dim user that dominates all other non-dim users.
SmallVector<OpOperand *> fusableUses;
for (auto &use : op->getUses()) {
// Collect all fusable user candidates.
SetVector<OpOperand *> fusableUses;
for (OpOperand &use : op->getUses()) {
Operation *user = use.getOwner();
if (ignoredUsers.contains(user)) {
if (isa<tensor::CollapseShapeOp, tensor::DimOp, tensor::ExpandShapeOp>(
user) ||
user->getDialect() ==
context->getLoadedDialect<IREE::Flow::FlowDialect>()) {
continue;
}
if (llvm::all_of(op->getUsers(), [&](Operation *otherUser) {
if (user == otherUser || ignoredUsers.contains(otherUser))
return true;
return dominanceInfo.dominates(user, otherUser);
})) {
fusableUses.push_back(&use);
ignoredUsers.insert(user);
}
fusableUses.insert(&use);
}
return fusableUses;

SmallVector<OpOperand *> usesVec = fusableUses.takeVector();
llvm::sort(usesVec, [&](OpOperand *lhsUse, OpOperand *rhsUse) {
return dominanceInfo.properlyDominates(lhsUse->getOwner(),
rhsUse->getOwner());
});

return usesVec;
}

/// Returns true if the operands are fusable.
Expand Down Expand Up @@ -590,8 +586,8 @@ fuseRootsWithConsumers(MLIRContext *context, ArrayRef<Operation *> roots,
Operation *currRoot = workList.pop_back_val();

SmallVector<OpOperand *> fusableUses =
getFusableUse(currRoot, dominanceInfo,
/*aggressiveFusion=*/options.aggressiveFusion);
getFusableUses(context, currRoot, dominanceInfo,
/*aggressiveFusion=*/options.aggressiveFusion);
if (fusableUses.empty())
continue;

Expand Down Expand Up @@ -708,8 +704,8 @@ fuseRootsWithProducers(MLIRContext *context, Operation *root, unsigned groupNum,
}

SmallVector<OpOperand *> fusableUses =
getFusableUse(producer, dominanceInfo,
/*aggressiveFusion=*/options.aggressiveFusion);
getFusableUses(context, producer, dominanceInfo,
/*aggressiveFusion=*/options.aggressiveFusion);
if (fusableUses.empty() || fusableUses.front()->getOwner() != candidate)
continue;

Expand Down Expand Up @@ -851,8 +847,8 @@ createFusionGroups(TensorDimTrackingRewriter &rewriter,
SmallVector<IREE::Flow::DispatchRegionOp> regionOps;
for (auto [rootIndex, root] : llvm::enumerate(roots)) {

// Sort producers topologically. All producers must be in the same block
// as the root.
// Sort producers and consumers topologically. All fused ops must be in the
// same block as the root.
SmallVector<Operation *> &currFusedOperations = fusedOperations[rootIndex];
bool sortResult = mlir::computeTopologicalSorting(currFusedOperations);
(void)sortResult;
Expand Down

0 comments on commit 0f9bd38

Please sign in to comment.