[mlir][sparse] implementing stageSparseOpPass as an interface (#69022)
PeimingLiu authored Oct 17, 2023
1 parent a22a1fe commit 761c9dd
Showing 11 changed files with 299 additions and 197 deletions.
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt
@@ -12,3 +12,9 @@ set(LLVM_TARGET_DEFINITIONS SparseTensorTypes.td)
mlir_tablegen(SparseTensorTypes.h.inc -gen-typedef-decls)
mlir_tablegen(SparseTensorTypes.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(MLIRSparseTensorTypesIncGen)

set(LLVM_TARGET_DEFINITIONS SparseTensorInterfaces.td)
mlir_tablegen(SparseTensorInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(SparseTensorInterfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRSparseTensorInterfacesIncGen)
add_dependencies(mlir-headers MLIRSparseTensorInterfacesIncGen)
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -11,6 +11,7 @@

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
31 changes: 31 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
@@ -0,0 +1,31 @@
//===- SparseTensorInterfaces.h - sparse tensor operations interfaces ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_

#include "mlir/IR/OpDefinition.h"

namespace mlir {
class PatternRewriter;

namespace sparse_tensor {
class StageWithSortSparseOp;

namespace detail {
LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op,
PatternRewriter &rewriter);
} // namespace detail
} // namespace sparse_tensor
} // namespace mlir

/// Include the generated interface declarations.
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h.inc"

#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
45 changes: 45 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
@@ -0,0 +1,45 @@
//===- SparseTensorInterfaces.td --------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef SPARSETENSOR_IR_SPARSETENSORINTERFACES
#define SPARSETENSOR_IR_SPARSETENSORINTERFACES

include "mlir/IR/OpBase.td"

def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
let description = [{
A stage-with-sort sparse tensor operation is an operation that produces
unordered intermediate output. An extra sort is required to obtain the final
ordered result.

E.g., convert csr -> csc needs to be implemented as
convert csr -> unordered coo -> sort by column -> csc; and
concatenate csr, csc -> csr can be staged into
concatenate csr, csc -> unordered coo -> sort by row -> csr.
}];
let cppNamespace = "::mlir::sparse_tensor";
let methods = [
InterfaceMethod<
/*desc=*/"Return true if the operation needs an extra sort to produce the final result.",
/*retTy=*/"bool",
/*methodName=*/"needsExtraSort",
/*args=*/(ins),
/*methodBody=*/"">,
InterfaceMethod<
/*desc=*/"Stage the operation, return the final result value after staging.",
/*retTy=*/"::mlir::LogicalResult",
/*methodName=*/"stageWithSort",
/*args=*/(ins "::mlir::PatternRewriter &":$rewriter),
/*methodBody=*/[{
return detail::stageWithSortImpl($_op, rewriter);
}]>,
];
}


#endif // SPARSETENSOR_IR_SPARSETENSORINTERFACES
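
The interface above lets a rewrite driver stage any implementing op without knowing whether it is a convert or a concatenate. A minimal sketch of such a driver follows, assuming standard MLIR headers; `StageSparseOpsPattern` is a hypothetical name, and this is not necessarily the pattern used by the staging pass in this commit (that file is not shown on this page):

```cpp
// Hypothetical interface-driven pattern (sketch, not the commit's pass code).
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

struct StageSparseOpsPattern
    : public OpInterfaceRewritePattern<StageWithSortSparseOp> {
  using OpInterfaceRewritePattern<
      StageWithSortSparseOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(StageWithSortSparseOp op,
                                PatternRewriter &rewriter) const override {
    // The default stageWithSort() body (see methodBody above) dispatches to
    // detail::stageWithSortImpl, which returns failure() when no extra sort
    // is needed, leaving directly lowerable ops untouched.
    return op.stageWithSort(rewriter);
  }
};
```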
18 changes: 13 additions & 5 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -12,6 +12,7 @@
include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

@@ -153,7 +154,7 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
}

def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
-    [Pure]>,
+    [Pure, StageWithSortSparseOpInterface]>,
Arguments<(ins AnyTensor:$source)>,
Results<(outs AnyTensor:$dest)> {
string summary = "Converts between different tensor types";
@@ -197,9 +198,9 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
}];

let extraClassDeclaration = [{
-    // Whether the convert can be done by a single step (either a sort or a foreach),
-    // or it would require a tmp buffer (sort, then foreach).
-    bool directConvertable();
+    // Whether the convert can be done in a single step or requires an extra
+    // sort. Inherited from StageWithSortSparseOpInterface.
+    bool needsExtraSort();
}];

let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
@@ -334,7 +335,8 @@ def SparseTensor_NumberOfEntriesOp : SparseTensor_Op<"number_of_entries", [Pure]
let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
}

-def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>,
+def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate",
+    [Pure, StageWithSortSparseOpInterface]>,
Arguments<(ins Variadic<AnyRankedTensor>:$inputs, DimensionAttr:$dimension)>,
Results<(outs AnyRankedTensor:$result)> {

@@ -357,6 +359,12 @@ def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>,
```
}];

let extraClassDeclaration = [{
    // Whether the concatenate can be done in a single step or requires an
    // extra sort. Inherited from StageWithSortSparseOpInterface.
    bool needsExtraSort();
}];

let assemblyFormat = "$inputs attr-dict `:` type($inputs) `to` type($result)";
let hasVerifier = 1;
}
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
@@ -29,6 +29,7 @@ endif()

add_mlir_dialect_library(MLIRSparseTensorDialect
SparseTensorDialect.cpp
SparseTensorInterfaces.cpp
Detail/Var.cpp
Detail/DimLvlMap.cpp
Detail/LvlTypeParser.cpp
31 changes: 24 additions & 7 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1065,18 +1065,18 @@ OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
return {};
}

-bool ConvertOp::directConvertable() {
+bool ConvertOp::needsExtraSort() {
SparseTensorType srcStt = getSparseTensorType(getSource());
SparseTensorType dstStt = getSparseTensorType(getDest());

-  // We can always directly convert to unordered sparse tensor or dense tensor
-  // since dense tensor support random access.
+  // We do not need an extra sort when returning unordered sparse tensors or
+  // dense tensors, since dense tensors support random access.
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
-    return true;
+    return false;

  if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
      srcStt.hasSameDimToLvl(dstStt)) {
-    return true;
+    return false;
  }
}

// Source and dest tensors are ordered in different ways. We only do direct
@@ -1086,9 +1086,9 @@ bool ConvertOp::directConvertable() {
// performance.
if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
if (isa<SparseElementsAttr>(constOp.getValue()))
-      return true;
+      return false;

-  return false;
+  return true;
}
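
As a sanity check on the flipped polarity, these are the expected answers for a few common conversions (encodings assumed, sketch only):

```cpp
// Assumed encodings; expected results of the logic above (sketch):
//   CSR -> CSC           : different dimToLvl ordering -> needsExtraSort() == true
//   CSR -> unordered COO : unordered destination       -> needsExtraSort() == false
//   CSR -> dense         : random-access destination   -> needsExtraSort() == false
//   CSR -> CSR           : same ordered dimToLvl       -> needsExtraSort() == false
```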

LogicalResult ToPositionsOp::verify() {
@@ -1248,6 +1248,23 @@ LogicalResult UnaryOp::verify() {
return success();
}

bool ConcatenateOp::needsExtraSort() {
SparseTensorType dstStt = getSparseTensorType(*this);
if (dstStt.isAllDense() || !dstStt.isAllOrdered())
return false;

bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
return getSparseTensorType(op).hasSameDimToLvl(dstStt);
});
  // TODO: When conDim != 0, as long as conDim corresponds to the first level
  // in all input/output buffers, and all input/output buffers have the same
  // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating
  // CSC matrices along the column dimension).
bool directLowerable =
allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
return !directLowerable;
}

LogicalResult ConcatenateOp::verify() {
const auto dstTp = getSparseTensorType(*this);
const Dimension concatDim = getDimension();
55 changes: 55 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -0,0 +1,55 @@
//===- SparseTensorInterfaces.cpp - SparseTensor interfaces impl ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"

LogicalResult
sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
PatternRewriter &rewriter) {
if (!op.needsExtraSort())
return failure();

Location loc = op.getLoc();
Type finalTp = op->getOpResult(0).getType();
SparseTensorType dstStt(finalTp.cast<RankedTensorType>());

Type srcCOOTp = getCOOFromTypeWithOrdering(
dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);

  // Clone the original operation, but change its output type to an unordered COO.
Operation *cloned = rewriter.clone(*op.getOperation());
rewriter.updateRootInPlace(cloned, [cloned, srcCOOTp]() {
cloned->getOpResult(0).setType(srcCOOTp);
});
Value srcCOO = cloned->getOpResult(0);

// -> sort
Type dstCOOTp = getCOOFromTypeWithOrdering(
dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
Value dstCOO = rewriter.create<ReorderCOOOp>(
loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);

// -> dest.
if (dstCOO.getType() == finalTp) {
rewriter.replaceOp(op, dstCOO);
} else {
// Need an extra conversion if the target type is not COO.
rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
}
  // TODO: deallocate the extra COOs; we should probably delegate this to the
  // buffer deallocation pass.
return success();
}
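
Wiring the helper into a pass then reduces to registering one interface pattern. A hedged sketch, reusing the hypothetical `StageSparseOpsPattern` from the earlier sketch; the commit's actual pass file is among the 11 changed files but not shown on this page, and the populate function name below is assumed:

```cpp
// Hypothetical registration (sketch); the function name is an assumption.
void populateStagingPatterns(RewritePatternSet &patterns) {
  // Once the pattern fires, e.g. a CSR -> CSC convert becomes:
  //   %coo    = sparse_tensor.convert %src      (clone retyped to unordered COO)
  //   %sorted = sparse_tensor.reorder_coo %coo  (HybridQuickSort, ordered COO)
  //   %dst    = sparse_tensor.convert %sorted   (ordered COO -> CSC)
  patterns.add<StageSparseOpsPattern>(patterns.getContext());
}
```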