[DispatchCreation] Changes to dispatch region in preparation for horizontal fusion changes. #19876

@@ -45,7 +45,8 @@ void TensorDimTrackingRewriter::notifyOperationErased(Operation *op) {
 void TensorDimTrackingRewriter::notifyOperationInserted(Operation *op,
                                                         InsertPoint previous) {
   IRRewriter::Listener::notifyOperationInserted(op, previous);
-  if (isa<tensor::DimOp>(op))
+  auto dimOp = dyn_cast<tensor::DimOp>(op);
+  if (dimOp && isa<OpResult>(dimOp.getSource()))
     dimOps.insert(op);
 }
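For context, the inserted check narrows tracking to `tensor.dim` ops whose source is itself an op result. A minimal sketch of the same predicate (the helper name is hypothetical, not from this PR):

static bool shouldTrackDimOp(Operation *op) {
  auto dimOp = dyn_cast<tensor::DimOp>(op);
  // isa<OpResult> is false for BlockArguments (e.g. function arguments),
  // which have no defining op to simplify against.
  return dimOp && isa<OpResult>(dimOp.getSource());
}

This pairs with the simplifyDimOps change below, which now skips BlockArgument sources.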

@@ -60,16 +61,21 @@ LogicalResult simplifyDimOps(RewriterBase &rewriter,
     std::optional<int64_t> idx = dimOp.getConstantIndex();
     if (!idx.has_value())
       continue;
 
+    if (isa<BlockArgument>(dimOp.getSource())) {
+      continue;
+    }
+
     // Only DimOps with ranked tensors are supported.
     auto tensorType =
         llvm::dyn_cast<RankedTensorType>(dimOp.getSource().getType());
     if (!tensorType)
       continue;
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(dimOp);
     if (!tensorType.isDynamicDim(*idx)) {
       // Rewrite static dimension with constant.
-      OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(dimOp);
       int64_t size = tensorType.getShape()[*idx];
       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(dimOp, size);
       continue;
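The guard hoist above relies on OpBuilder::InsertionGuard's RAII behavior; a minimal sketch of that pattern (the helper is hypothetical, not from this PR):

void withInsertionAt(RewriterBase &rewriter, Operation *anchor,
                     llvm::function_ref<void()> build) {
  OpBuilder::InsertionGuard g(rewriter); // saves the current insertion point
  rewriter.setInsertionPoint(anchor);    // new ops are created before anchor
  build();                               // e.g. replaceOpWithNewOp<...>(...)
}                                        // insertion point restored on scope exit

With the guard placed before the branch, the same insertion point now also covers the dynamic-dim rewrite in the elided remainder of the loop body, and it is reset on every iteration.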
@@ -266,18 +266,8 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
   // Value is an OpResult.
   Operation *op = value.getDefiningOp();
   OpResult opResult = llvm::cast<OpResult>(value);
-  b.setInsertionPoint(op);
 
-  // Case 3: Value is tied. Reify the dimensions of the tied operand.
-  auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op);
-  if (tiedOp) {
-    Value tiedOperand = tiedOp.getTiedResultOperand(value);
-    if (tiedOperand && tiedOperand.getType() == value.getType())
-      return reifyDynamicResultDimsImpl(b, tiedOperand, dynamicDims,
-                                        createTensorDimOps);
-  }
-
-  // Case 4: Query ShapeAwareOpInterface.
+  // Case 3: Query ShapeAwareOpInterface.
   auto shapeAwareOp = dyn_cast<IREE::Util::ShapeAwareOpInterface>(op);
   if (shapeAwareOp) {
     ValueRange dims =
@@ -286,6 +276,15 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
     return success();
   }
 
+  // Case 4: Value is tied. Reify the dimensions of the tied operand.
+  auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op);
+  if (tiedOp) {
+    Value tiedOperand = tiedOp.getTiedResultOperand(value);
+    if (tiedOperand && tiedOperand.getType() == value.getType())
+      return reifyDynamicResultDimsImpl(b, tiedOperand, dynamicDims,
+                                        /*createTensorDimOps=*/true);
+  }
+
   // Case 5: Query ReifyRankedShapedTypeOpInterface.
   auto reifyShapeOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
   if (reifyShapeOp) {
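Taken together, the two hunks above reorder the lookup so the ShapeAwareOpInterface query wins over the tied-operand walk. A condensed sketch of the resulting order (the helper is hypothetical, assumes the IREE::Util interface methods used in the surrounding code, and elides case 5):

static LogicalResult reifyOpResultDimsSketch(OpResult result,
                                             SmallVectorImpl<Value> &dims) {
  Operation *op = result.getOwner();
  // Case 3: shape-aware ops report their dynamic dims directly.
  if (auto shapeAware = dyn_cast<IREE::Util::ShapeAwareOpInterface>(op)) {
    llvm::append_range(
        dims, shapeAware.getResultDynamicDims(result.getResultNumber()));
    return success();
  }
  // Case 4: otherwise recurse into a tied operand of identical type.
  if (auto tied = dyn_cast<IREE::Util::TiedOpInterface>(op)) {
    Value operand = tied.getTiedResultOperand(result);
    if (operand && operand.getType() == result.getType())
      if (auto tiedResult = dyn_cast<OpResult>(operand))
        return reifyOpResultDimsSketch(tiedResult, dims);
  }
  return failure(); // case 5 (ReifyRankedShapedTypeOpInterface) elided
}

Note the tied-operand recursion now passes /*createTensorDimOps=*/true unconditionally instead of forwarding the caller's flag.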
@@ -308,8 +307,14 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
 }
 
 /// Reify the dynamic dimensions of the given value.
+/// Deprecated. Use `getOptimizedDynamicResultDims` instead.
 LogicalResult reifyDynamicResultDims(OpBuilder &b, Value value,
                                      SmallVectorImpl<Value> &dynamicDims) {
+
+  OpBuilder::InsertionGuard g(b);
+  if (auto op = value.getDefiningOp()) {
+    b.setInsertionPoint(op);
+  }
   return reifyDynamicResultDimsImpl(b, value, dynamicDims,
                                     /*createTensorDimOps=*/true);
 }
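Call sites migrate to the new entry point; a hypothetical usage, assuming getOptimizedDynamicResultDims keeps the same shape as the deprecated function (the call sites below suggest it does):

SmallVector<Value> dynamicDims;
if (failed(getOptimizedDynamicResultDims(rewriter, result, dynamicDims)))
  return failure();
// On success, dynamicDims holds one Value per dynamic dimension of `result`.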
@@ -473,7 +478,7 @@ movePrecedingOpsIntoDispatchRegion(RewriterBase &rewriter,
     rewriter.setInsertionPoint(target);
     SmallVector<Value> &dims =
         dispatchOpNewResultsDynamicDims.emplace_back();
-    if (failed(reifyDynamicResultDims(rewriter, result, dims))) {
+    if (failed(getOptimizedDynamicResultDims(rewriter, result, dims))) {
       return target->emitOpError(
           "failed to reify dynamic dims of result to be yielded from "
           "dispatch region");
@@ -554,9 +559,10 @@ moveFollowingOpIntoDispatchRegion(RewriterBase &rewriter, Operation *target,
   for (auto [index, result] : llvm::enumerate(target->getResults())) {
     replacedValues.push_back(result);
     yieldedResults.push_back(clonedTarget->getResult(index));
-    rewriter.setInsertionPoint(target);
+    OpBuilder::InsertionGuard g1(rewriter);
+    rewriter.setInsertionPoint(regionOp);
     SmallVector<Value> &dims = dispatchOpNewResultsDynamicDims.emplace_back();
-    if (failed(reifyDynamicResultDims(rewriter, result, dims))) {
+    if (failed(getOptimizedDynamicResultDims(rewriter, result, dims))) {
       return target->emitOpError(
           "failed to reify dynamic dims of result to be yielded from "
           "dispatch region");
@@ -692,7 +692,12 @@ isContractionOpSequence(Value yielded) {
 /// Recognize an operation that is horizontally fused contraction.
 /// TODO: The logic below is quite convoluted. Might be better
 /// off having a dedicated operation for this.
-bool isaHorizontallyFusedContraction(linalg::LinalgOp linalgOp) {
+bool isaHorizontallyFusedContraction(Operation *op) {
+  auto linalgOp = dyn_cast_or_null<linalg::GenericOp>(op);
+  if (!linalgOp) {
+    return false;
+  }
+
   if (linalgOp->getNumResults() == 1) {
     return false;
   }
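With the signature relaxed to Operation *, the dyn_cast moves inside the predicate; a hypothetical call site:

void classifyOp(Operation *op) {
  // No pre-cast to linalg::LinalgOp (or null check) needed; null and
  // non-generic ops simply return false.
  if (IREE::LinalgExt::isaHorizontallyFusedContraction(op)) {
    // ... handle the horizontally fused contraction ...
  }
}

Note the cast also tightens from linalg::LinalgOp to linalg::GenericOp, so named contraction ops now return false.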
@@ -214,7 +214,7 @@ bool isGatherlikeOp(Operation *op);
 /// Check if a given operation is a horizontally fused contraction operation.
 /// The expectation is that the LHS is common, and all the operands are
 /// different RHS.
-bool isaHorizontallyFusedContraction(linalg::LinalgOp genericOp);
+bool isaHorizontallyFusedContraction(Operation *op);
 
 } // namespace mlir::iree_compiler::IREE::LinalgExt
 #endif // IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_
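The doc comment's invariant (a common LHS, pairwise-distinct RHS operands) can be made concrete; a simplified, hypothetical predicate over (lhs, rhs) pairs, not the PR's actual implementation:

static bool hasCommonLhsDistinctRhs(
    ArrayRef<std::pair<Value, Value>> contractions) {
  if (contractions.empty())
    return false;
  Value lhs = contractions.front().first;
  llvm::SmallDenseSet<Value, 4> seenRhs;
  for (auto [l, r] : contractions) {
    if (l != lhs)                    // every contraction must share one LHS
      return false;
    if (!seenRhs.insert(r).second)   // each RHS must be distinct
      return false;
  }
  return true;
}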