Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support of vecmat/matvec in SetEncoding and MaterializeEncoding #15257

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 73 additions & 96 deletions compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,25 +146,29 @@ static Value unsetEncodingAndExtractSlice(OpBuilder &builder, Location loc,

namespace {

/// Rewrites the matmul op to work on tensors with encoding. Optionally
/// Rewrites contraction ops to work on tensors with encoding. Optionally
/// also pads the operands.
struct SetMatmulEncoding : public OpRewritePattern<linalg::MatmulOp> {
SetMatmulEncoding(MLIRContext *context, PatternBenefit benefit = 1)
: OpRewritePattern<linalg::MatmulOp>(context, benefit) {}

LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
/// Currently works on matmul, batch_matmul, vecmat, matvec and batch_matvec.
struct SetContractionOpEncoding
: public OpInterfaceRewritePattern<linalg::ContractionOpInterface> {
using OpInterfaceRewritePattern<
linalg::ContractionOpInterface>::OpInterfaceRewritePattern;
LogicalResult matchAndRewrite(linalg::ContractionOpInterface op,
PatternRewriter &rewriter) const override {
if (!matmulOp.hasTensorSemantics())
auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
if (!linalgOp || !linalgOp.hasTensorSemantics())
return failure();
auto inputs = matmulOp.getDpsInputs();
auto outputs = matmulOp.getDpsInits();

auto inputs = linalgOp.getDpsInputs();
auto outputs = linalgOp.getDpsInits();
auto hasEncoding = [](Value operand) -> bool {
auto type = llvm::dyn_cast<RankedTensorType>(operand.getType());
return type && type.getEncoding();
};
if (llvm::any_of(inputs, hasEncoding) ||
llvm::any_of(outputs, hasEncoding)) {
return failure();
return rewriter.notifyMatchFailure(op,
"input/output already has encoding");
}

Value origLhs = inputs[0];
Expand All @@ -182,85 +186,26 @@ struct SetMatmulEncoding : public OpRewritePattern<linalg::MatmulOp> {
Type outElemType = getElemType(origOut);

if (!lhsElemType || !rhsElemType || !outElemType) {
return failure();
return rewriter.notifyMatchFailure(op, "invalid input/output");
}

IREE::LinalgExt::EncodingUser user = IREE::LinalgExt::EncodingUser::MATMUL;
Location loc = matmulOp.getLoc();
TypeRange operandTypes = matmulOp->getOperandTypes();
Value encodedLhs =
padAndSetEncoding(rewriter, loc, origLhs, user,
IREE::LinalgExt::EncodingRole::LHS, operandTypes);
Value encodedRhs =
padAndSetEncoding(rewriter, loc, origRhs, user,
IREE::LinalgExt::EncodingRole::RHS, operandTypes);
Value encodedOut =
padAndSetEncoding(rewriter, loc, origOut, user,
IREE::LinalgExt::EncodingRole::RESULT, operandTypes);

Value matmulTiled = rewriter
.create<linalg::MatmulOp>(
loc, encodedOut.getType(),
ValueRange{encodedLhs, encodedRhs}, encodedOut)
.getResult(0);

// Sizes are computed by original output size.
FailureOr<SmallVector<OpFoldResult>> origOutSizes =
IREE::LinalgExt::getDims(rewriter, loc, origOut);
if (failed(origOutSizes)) {
return rewriter.notifyMatchFailure(matmulOp,
"failed to get shape of result");
IREE::LinalgExt::EncodingUser user;
if (op.isRowMajorMatmul() || op.isColumnMajorMatmul()) {
user = IREE::LinalgExt::EncodingUser::MATMUL;
} else if (op.isRowMajorBatchMatmul()) {
user = IREE::LinalgExt::EncodingUser::BATCH_MATMUL;
} else if (op.isVecmat()) {
user = IREE::LinalgExt::EncodingUser::VECMAT;
} else if (op.isMatvec()) {
user = IREE::LinalgExt::EncodingUser::MATVEC;
} else if (op.isBatchMatvec()) {
user = IREE::LinalgExt::EncodingUser::BATCH_MATVEC;
} else {
return rewriter.notifyMatchFailure(op, "unsupported contraction op");
}

Value result = unsetEncodingAndExtractSlice(rewriter, loc, matmulTiled,
origOutSizes.value());

rewriter.replaceOp(matmulOp, result);
return success();
}
};

struct SetBatchMatmulEncoding : public OpRewritePattern<linalg::BatchMatmulOp> {
SetBatchMatmulEncoding(MLIRContext *context, PatternBenefit benefit = 1)
: OpRewritePattern<linalg::BatchMatmulOp>(context, benefit) {}

LogicalResult matchAndRewrite(linalg::BatchMatmulOp matmulOp,
PatternRewriter &rewriter) const override {
if (!matmulOp.hasTensorSemantics())
return failure();
auto inputs = matmulOp.getDpsInputs();
auto outputs = matmulOp.getDpsInits();
auto hasEncoding = [](Value operand) -> bool {
auto type = llvm::dyn_cast<RankedTensorType>(operand.getType());
return type && type.getEncoding();
};
if (llvm::any_of(inputs, hasEncoding) ||
llvm::any_of(outputs, hasEncoding)) {
return failure();
}

Value origLhs = inputs[0];
Value origRhs = inputs[1];
Value origOut = outputs[0];

auto getElemType = [](Value v) -> Type {
if (auto tensorType = llvm::dyn_cast<RankedTensorType>(v.getType())) {
return tensorType.getElementType();
}
return {};
};
Type lhsElemType = getElemType(origLhs);
Type rhsElemType = getElemType(origRhs);
Type outElemType = getElemType(origOut);

if (!lhsElemType || !rhsElemType || !outElemType) {
return failure();
}

IREE::LinalgExt::EncodingUser user =
IREE::LinalgExt::EncodingUser::BATCH_MATMUL;
Location loc = matmulOp.getLoc();
TypeRange operandTypes = matmulOp->getOperandTypes();
Location loc = op.getLoc();
TypeRange operandTypes = op->getOperandTypes();
Value encodedLhs =
padAndSetEncoding(rewriter, loc, origLhs, user,
IREE::LinalgExt::EncodingRole::LHS, operandTypes);
Expand All @@ -271,24 +216,56 @@ struct SetBatchMatmulEncoding : public OpRewritePattern<linalg::BatchMatmulOp> {
padAndSetEncoding(rewriter, loc, origOut, user,
IREE::LinalgExt::EncodingRole::RESULT, operandTypes);

Value matmulTiled = rewriter
.create<linalg::BatchMatmulOp>(
loc, encodedOut.getType(),
ValueRange{encodedLhs, encodedRhs}, encodedOut)
.getResult(0);
Value opTiled;
switch (user) {
case IREE::LinalgExt::EncodingUser::MATMUL:
opTiled = rewriter
.create<linalg::MatmulOp>(
loc, encodedOut.getType(),
ValueRange{encodedLhs, encodedRhs}, encodedOut)
.getResult(0);
break;

case IREE::LinalgExt::EncodingUser::BATCH_MATMUL:
opTiled = rewriter
.create<linalg::BatchMatmulOp>(
loc, encodedOut.getType(),
ValueRange{encodedLhs, encodedRhs}, encodedOut)
.getResult(0);
break;
case IREE::LinalgExt::EncodingUser::VECMAT:
opTiled = rewriter
.create<linalg::VecmatOp>(
loc, encodedOut.getType(),
ValueRange{encodedLhs, encodedRhs}, encodedOut)
.getResult(0);
break;
case IREE::LinalgExt::EncodingUser::MATVEC:
opTiled = rewriter
.create<linalg::MatvecOp>(
loc, encodedOut.getType(),
ValueRange{encodedLhs, encodedRhs}, encodedOut)
.getResult(0);
break;
case IREE::LinalgExt::EncodingUser::BATCH_MATVEC:
opTiled = rewriter
.create<linalg::BatchMatvecOp>(
loc, encodedOut.getType(),
ValueRange{encodedLhs, encodedRhs}, encodedOut)
.getResult(0);
break;
}
Comment on lines +220 to +257
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


// Sizes are computed by original output size.
FailureOr<SmallVector<OpFoldResult>> origOutSizes =
IREE::LinalgExt::getDims(rewriter, loc, origOut);
if (failed(origOutSizes)) {
return rewriter.notifyMatchFailure(matmulOp,
"failed to get shape of result");
return rewriter.notifyMatchFailure(op, "failed to get shape of result");
}

Value result = unsetEncodingAndExtractSlice(rewriter, loc, matmulTiled,
Value result = unsetEncodingAndExtractSlice(rewriter, loc, opTiled,
origOutSizes.value());

rewriter.replaceOp(matmulOp, result);
rewriter.replaceOp(linalgOp, result);
return success();
}
};
Expand Down Expand Up @@ -332,7 +309,7 @@ void SetEncodingPass::runOnOperation() {
MLIRContext *context = &getContext();
{
RewritePatternSet patterns(context);
patterns.insert<SetBatchMatmulEncoding, SetMatmulEncoding>(context);
patterns.insert<SetContractionOpEncoding>(context);
linalg::FillOp::getCanonicalizationPatterns(patterns, context);
patterns.insert<FoldFillWithSetEncoding>(context);
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
Expand Down
Loading