-
Notifications
You must be signed in to change notification settings - Fork 12.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir][sparse] implementating stageSparseOpPass as an interface (#69022)
- Loading branch information
1 parent
a22a1fe
commit 761c9dd
Showing
11 changed files
with
299 additions
and
197 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
31 changes: 31 additions & 0 deletions
31
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
//===- SparseTensorInterfaces.h - sparse tensor operations | ||
//interfaces-------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_ | ||
#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_ | ||
|
||
#include "mlir/IR/OpDefinition.h" | ||
|
||
namespace mlir { | ||
class PatternRewriter; | ||
|
||
namespace sparse_tensor { | ||
class StageWithSortSparseOp; | ||
|
||
namespace detail { | ||
LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op, | ||
PatternRewriter &rewriter); | ||
} // namespace detail | ||
} // namespace sparse_tensor | ||
} // namespace mlir | ||
|
||
/// Include the generated interface declarations. | ||
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h.inc" | ||
|
||
#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_ |
45 changes: 45 additions & 0 deletions
45
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
//===- SparseTensorInterfaces.td --------------------------*- tablegen -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef SPARSETENSOR_IR_SPARSETENSORINTERFACES | ||
#define SPARSETENSOR_IR_SPARSETENSORINTERFACES | ||
|
||
include "mlir/IR/OpBase.td" | ||
|
||
def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> { | ||
let description = [{ | ||
A stage-with-sort sparse tensor operation is an operation that produces | ||
unordered intermediate output. An extra sort is required to obtain the final | ||
ordered result. | ||
|
||
E.g., convert csr -> csc need to be implemented as | ||
convert csr -> unordered coo -> sort by column -> csc; and | ||
concatenate csr, csc -> csr can be staged into | ||
concatenate csr, csr -> unordered coo -> sort by row -> csr. | ||
}]; | ||
let cppNamespace = "::mlir::sparse_tensor"; | ||
let methods = [ | ||
InterfaceMethod< | ||
/*desc=*/"Return true if the operation needs an extra sort to produce the final result.", | ||
/*retTy=*/"bool", | ||
/*methodName=*/"needsExtraSort", | ||
/*args=*/(ins), | ||
/*methodBody=*/"">, | ||
InterfaceMethod< | ||
/*desc=*/"Stage the operation, return the final result value after staging.", | ||
/*retTy=*/"::mlir::LogicalResult", | ||
/*methodName=*/"stageWithSort", | ||
/*args=*/(ins "::mlir::PatternRewriter &":$rewriter), | ||
/*methodBody=*/[{ | ||
return detail::stageWithSortImpl($_op, rewriter); | ||
}]>, | ||
]; | ||
} | ||
|
||
|
||
#endif // SPARSETENSOR_IR_SPARSETENSORINTERFACES |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
55 changes: 55 additions & 0 deletions
55
mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
//===- SparseTensorInterfaces.cpp - SparseTensor interfaces impl ----------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h" | ||
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" | ||
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::sparse_tensor; | ||
|
||
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc" | ||
|
||
LogicalResult | ||
sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op, | ||
PatternRewriter &rewriter) { | ||
if (!op.needsExtraSort()) | ||
return failure(); | ||
|
||
Location loc = op.getLoc(); | ||
Type finalTp = op->getOpResult(0).getType(); | ||
SparseTensorType dstStt(finalTp.cast<RankedTensorType>()); | ||
|
||
Type srcCOOTp = getCOOFromTypeWithOrdering( | ||
dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false); | ||
|
||
// Clones the original operation but changing the output to an unordered COO. | ||
Operation *cloned = rewriter.clone(*op.getOperation()); | ||
rewriter.updateRootInPlace(cloned, [cloned, srcCOOTp]() { | ||
cloned->getOpResult(0).setType(srcCOOTp); | ||
}); | ||
Value srcCOO = cloned->getOpResult(0); | ||
|
||
// -> sort | ||
Type dstCOOTp = getCOOFromTypeWithOrdering( | ||
dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true); | ||
Value dstCOO = rewriter.create<ReorderCOOOp>( | ||
loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort); | ||
|
||
// -> dest. | ||
if (dstCOO.getType() == finalTp) { | ||
rewriter.replaceOp(op, dstCOO); | ||
} else { | ||
// Need an extra conversion if the target type is not COO. | ||
rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO); | ||
} | ||
// TODO: deallocate extra COOs, we should probably delegate it to buffer | ||
// deallocation pass. | ||
return success(); | ||
} |
Oops, something went wrong.