Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][sparse] implementating stageSparseOpPass as an interface #69022

Merged
merged 9 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,9 @@ set(LLVM_TARGET_DEFINITIONS SparseTensorTypes.td)
mlir_tablegen(SparseTensorTypes.h.inc -gen-typedef-decls)
mlir_tablegen(SparseTensorTypes.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(MLIRSparseTensorTypesIncGen)

set(LLVM_TARGET_DEFINITIONS SparseTensorInterfaces.td)
mlir_tablegen(SparseTensorInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(SparseTensorInterfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRSparseTensorInterfacesIncGen)
add_dependencies(mlir-headers MLIRSparseTensorInterfacesIncGen)
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===- SparseTensorInterfaces.h - sparse tensor operations
//interfaces-------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_

#include "mlir/IR/OpDefinition.h"

namespace mlir {
class PatternRewriter;

namespace sparse_tensor {
class StageWithSortSparseOp;

namespace detail {
LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op,
PatternRewriter &rewriter);
} // namespace detail
} // namespace sparse_tensor
} // namespace mlir

/// Include the generated interface declarations.
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h.inc"

#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//===- SparseTensorInterfaces.td --------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef SPARSETENSOR_IR_SPARSETENSORINTERFACES
#define SPARSETENSOR_IR_SPARSETENSORINTERFACES

include "mlir/IR/OpBase.td"

def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
let description = [{
A stage-with-sort sparse tensor operation is an operation that produces
unordered intermediate output. An extra sort is required to obtain the final
ordered result.

E.g., convert csr -> csc need to be implemented as
convert csr -> unordered coo -> sort by column -> csc; and
concatenate csr, csc -> csr can be staged into
concatenate csr, csr -> unordered coo -> sort by row -> csr.
}];
let cppNamespace = "::mlir::sparse_tensor";
let methods = [
InterfaceMethod<
/*desc=*/"Return true if the operation needs an extra sort to produce the final result.",
/*retTy=*/"bool",
/*methodName=*/"needsExtraSort",
/*args=*/(ins),
/*methodBody=*/"">,
InterfaceMethod<
/*desc=*/"Stage the operation, return the final result value after staging.",
/*retTy=*/"::mlir::LogicalResult",
/*methodName=*/"stageWithSort",
/*args=*/(ins "::mlir::PatternRewriter &":$rewriter),
/*methodBody=*/[{
return detail::stageWithSortImpl($_op, rewriter);
}]>,
];
}


#endif // SPARSETENSOR_IR_SPARSETENSORINTERFACES
18 changes: 13 additions & 5 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

Expand Down Expand Up @@ -153,7 +154,7 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
}

def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
[Pure]>,
[Pure, StageWithSortSparseOpInterface]>,
Arguments<(ins AnyTensor:$source)>,
Results<(outs AnyTensor:$dest)> {
string summary = "Converts between different tensor types";
Expand Down Expand Up @@ -197,9 +198,9 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
}];

let extraClassDeclaration = [{
// Whether the convert can be done by a single step (either a sort or a foreach),
// or it would require a tmp buffer (sort, then foreach).
bool directConvertable();
// Whether the convert can be done by a single step or it would require
// an extra sort. Inherited from StageWithSortSparseOpInterface.
bool needsExtraSort();
}];

let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
Expand Down Expand Up @@ -334,7 +335,8 @@ def SparseTensor_NumberOfEntriesOp : SparseTensor_Op<"number_of_entries", [Pure]
let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
}

def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>,
def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate",
[Pure, StageWithSortSparseOpInterface]>,
Arguments<(ins Variadic<AnyRankedTensor>:$inputs, DimensionAttr:$dimension)>,
Results<(outs AnyRankedTensor:$result)> {

Expand All @@ -357,6 +359,12 @@ def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>,
```
}];

let extraClassDeclaration = [{
// Whether the concatenate can be done by a single step or it would require
// an extra sort. Inherited from StageWithSortSparseOpInterface.
bool needsExtraSort();
}];

let assemblyFormat = "$inputs attr-dict `:` type($inputs) `to` type($result)";
let hasVerifier = 1;
}
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ endif()

add_mlir_dialect_library(MLIRSparseTensorDialect
SparseTensorDialect.cpp
SparseTensorInterfaces.cpp
Detail/Var.cpp
Detail/DimLvlMap.cpp
Detail/LvlTypeParser.cpp
Expand Down
31 changes: 24 additions & 7 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1065,18 +1065,18 @@ OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
return {};
}

bool ConvertOp::directConvertable() {
bool ConvertOp::needsExtraSort() {
SparseTensorType srcStt = getSparseTensorType(getSource());
SparseTensorType dstStt = getSparseTensorType(getDest());

// We can always directly convert to unordered sparse tensor or dense tensor
// since dense tensor support random access.
// We do not need an extra sort when returning unordered sparse tensors or
// dense tensor since dense tensor support random access.
if (dstStt.isAllDense() || !dstStt.isAllOrdered())
return true;
return false;

if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
srcStt.hasSameDimToLvl(dstStt)) {
return true;
return false;
}

// Source and dest tensors are ordered in different ways. We only do direct
Expand All @@ -1086,9 +1086,9 @@ bool ConvertOp::directConvertable() {
// performance.
if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
if (isa<SparseElementsAttr>(constOp.getValue()))
return true;
return false;

return false;
return true;
}

LogicalResult ToPositionsOp::verify() {
Expand Down Expand Up @@ -1248,6 +1248,23 @@ LogicalResult UnaryOp::verify() {
return success();
}

bool ConcatenateOp::needsExtraSort() {
SparseTensorType dstStt = getSparseTensorType(*this);
if (dstStt.isAllDense() || !dstStt.isAllOrdered())
return false;

bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
return getSparseTensorType(op).hasSameDimToLvl(dstStt);
});
// TODO: When conDim != 0, as long as conDim corresponding to the first level
// in all input/output buffers, and all input/output buffers have the same
// dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate
// CSC matrices along column).
bool directLowerable =
allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
return !directLowerable;
}

LogicalResult ConcatenateOp::verify() {
const auto dstTp = getSparseTensorType(*this);
const Dimension concatDim = getDimension();
Expand Down
55 changes: 55 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//===- SparseTensorInterfaces.cpp - SparseTensor interfaces impl ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"

LogicalResult
sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
PatternRewriter &rewriter) {
if (!op.needsExtraSort())
return failure();

Location loc = op.getLoc();
Type finalTp = op->getOpResult(0).getType();
SparseTensorType dstStt(finalTp.cast<RankedTensorType>());

Type srcCOOTp = getCOOFromTypeWithOrdering(
dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);

// Clones the original operation but changing the output to an unordered COO.
Operation *cloned = rewriter.clone(*op.getOperation());
rewriter.updateRootInPlace(cloned, [cloned, srcCOOTp]() {
cloned->getOpResult(0).setType(srcCOOTp);
});
Value srcCOO = cloned->getOpResult(0);

// -> sort
Type dstCOOTp = getCOOFromTypeWithOrdering(
dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
Value dstCOO = rewriter.create<ReorderCOOOp>(
loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);

// -> dest.
if (dstCOO.getType() == finalTp) {
rewriter.replaceOp(op, dstCOO);
} else {
// Need an extra conversion if the target type is not COO.
rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
}
// TODO: deallocate extra COOs, we should probably delegate it to buffer
// deallocation pass.
return success();
}
Loading