Skip to content

Commit

Permalink
Restrict multi-consumer fusion to only the horizontal fusion cases.
Browse files Browse the repository at this point in the history
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
  • Loading branch information
MaheshRavishankar committed Feb 12, 2025
1 parent b9e403a commit e92f8aa
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
7 changes: 6 additions & 1 deletion compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,12 @@ isContractionOpSequence(Value yielded) {
/// Recognize an operation that is horizontally fused contraction.
/// TODO: The logic below is quite convoluted. Might be better
/// off having a dedicated operation for this.
bool isaHorizontallyFusedContraction(linalg::LinalgOp linalgOp) {
bool isaHorizontallyFusedContraction(Operation *op) {
auto linalgOp = dyn_cast_or_null<linalg::GenericOp>(op);
if (!linalgOp) {
return false;
}

if (linalgOp->getNumResults() == 1) {
return false;
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ bool isGatherlikeOp(Operation *op);
/// Check if a given operation is a horizontally fused contraction operation.
/// The expectation is that the LHS is common, and all the operands are
/// different RHS.
bool isaHorizontallyFusedContraction(linalg::LinalgOp genericOp);
bool isaHorizontallyFusedContraction(Operation *op);

} // namespace mlir::iree_compiler::IREE::LinalgExt
#endif // IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,13 @@ fuseRootsWithConsumers(MLIRContext *context, ArrayRef<Operation *> roots,
if (fusableUses.empty())
continue;

// For now disable the fusing with multiple consumers for all
// operations other than horizontally fused gemms. This should
// work in general but is causing time-outs on some CI examples.
if (!IREE::LinalgExt::isaHorizontallyFusedContraction(root)) {
fusableUses = {fusableUses.front()};
}

// Analyse the use to see if it is fusable.
for (OpOperand *fusableUse : fusableUses) {
Operation *consumerOp = fusableUse->getOwner();
Expand Down

0 comments on commit e92f8aa

Please sign in to comment.