[mlir][sparse] avoid tensor to memref conversion in sparse tensor rewri… #69362

Merged: 1 commit, Oct 17, 2023
107 changes: 41 additions & 66 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -829,47 +829,40 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
   }
 };

 // A trivial wrapper to help generate different operations for dense/sparse
 // tensors.
 struct TensorLike {
   TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
-             ValueRange sizes)
-      : isSparse(rtt.getEncoding() != nullptr) {
+             ValueRange sizes) {
     SmallVector<Value> dynSzs;
     getDynamicSizes(rtt, sizes, dynSzs);

-    if (isSparse)
-      val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
-    else
-      val = allocDenseTensor(builder, loc, rtt, sizes);
-  };
-
-  void insertOrStore(OpBuilder &builder, Location loc, Value v,
-                     ValueRange crds) {
-    if (isSparse)
-      val = builder.create<InsertOp>(loc, v, val, crds);
-    else
-      builder.create<memref::StoreOp>(loc, v, val, crds);
+    val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
+    if (!isSparse()) {
+      Value c0 = constantZero(builder, loc, rtt.getElementType());
+      val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
+    }
   }

-  Value getSSA() const {
-    // We don't need to maintain the SSA chain for a memref value.
-    return isSparse ? val : nullptr;
+  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
+    // TODO: Unify these two.
+    if (isSparse())
+      val = builder.create<sparse_tensor::InsertOp>(loc, v, val, crds);
+    else
+      val = builder.create<tensor::InsertOp>(loc, v, val, crds);
   }

   Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
-    if (isSparse)
+    if (isSparse())
       return builder.create<LoadOp>(loc, val, true);
-    return builder.create<bufferization::ToTensorOp>(loc, rtp, val);
+    return val;
   }

-  void updateSSA(Value v) {
-    // Dense memref is a non-SSA value.
-    assert(isSparse);
-    val = v;
+  bool isSparse() const {
+    return getSparseTensorEncoding(val.getType()) != nullptr;
   }

-private:
-  bool isSparse;
-  Value val; // either a memref (for dense tensor) or a sparse tensor.
+  Value val;
 };

 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
@@ -901,14 +894,14 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {

     TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
     Value offset = constantIndex(rewriter, loc, 0);
-    Value iterArg = dstBuf.getSSA();
+    Value iterArg = dstBuf.val;

     ForeachOp foreachOp;
     for (Value input : op.getInputs()) {
       // Builds a for op for each input tensor to append new values into the
       // output tensor.
       foreachOp = rewriter.create<ForeachOp>(
-          loc, input, iterArg ? ValueRange{iterArg} : ValueRange{},
+          loc, input, iterArg,
           [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
               ValueRange reduc) {
             SmallVector<Value> dstLcvs(dstTp.getLvlRank());
@@ -920,32 +913,26 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
               // FIXME: `toStoredDim` is deprecated
               dstLcvs[toStoredDim(dstTp.getEncoding(), d)] = crd;
             }
-
-            if (!reduc.empty())
-              dstBuf.updateSSA(reduc.front());
-
+            // Enters foreach, updates the SSA chain.
+            dstBuf.val = reduc.front();
             if (!dstTp.isAllDense()) {
               Value cond = genIsNonzero(builder, loc, v);
               auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                     /*else*/ true);
               builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+              builder.create<scf::YieldOp>(loc, dstBuf.val);

               builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
-              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+              dstBuf.insert(builder, loc, v, dstLcvs);
+              builder.create<scf::YieldOp>(loc, dstBuf.val);

               // Exits the ifOp, update the sparse tensor SSA value.
               builder.setInsertionPointAfter(ifOp);
-              assert(!reduc.empty());
-              dstBuf.updateSSA(ifOp.getResult(0));
+              dstBuf.val = ifOp.getResult(0);
             } else {
-              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
+              dstBuf.insert(builder, loc, v, dstLcvs);
             }
-            if (reduc.empty())
-              builder.create<sparse_tensor::YieldOp>(loc);
-            else
-              builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
+            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
           });
       // Accumulates the offset. Note that only static-shaped inputs are allowed
       // by concatenate op verifier, which saves us from computing the offset
@@ -955,15 +942,11 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
       offset = rewriter.create<arith::AddIOp>(
          loc, offset, constantIndex(rewriter, loc, *sh));

-      if (!foreachOp.getResults().empty()) {
-        iterArg = foreachOp.getResult(0);
-        dstBuf.updateSSA(iterArg);
-      }
+      iterArg = foreachOp.getResult(0);
+      dstBuf.val = iterArg;
     }

-    if (!foreachOp.getResults().empty())
-      dstBuf.updateSSA(iterArg);
-
+    dstBuf.val = iterArg;
     Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
     rewriter.replaceOp(op, ret);
     return success();
@@ -1010,15 +993,12 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
     ValueRange vs;
     TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);

-    Value iterArg = dstBuf.getSSA();
     auto foreachOp = rewriter.create<ForeachOp>(
-        loc, src, iterArg ? ValueRange{iterArg} : ValueRange{}, foreachOrder,
+        loc, src, dstBuf.val, foreachOrder,
         [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
             ValueRange reduc) {
           // Enters the loop, update the SSA value for insertion chain.
-          if (!reduc.empty())
-            dstBuf.updateSSA(reduc.front());
-
+          dstBuf.val = reduc.front();
           const Dimension dimRank = dstStt.getDimRank();
           const Level lvlRank = dstStt.getLvlRank();
           SmallVector<Value> lcvs(lvlRank);
@@ -1028,34 +1008,29 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
           }

           if (!skipZeroCheck) {
-            assert(!reduc.empty());
             Value cond = genIsNonzero(builder, loc, v);
             auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                   /*else*/ true);
             builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-            builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+            builder.create<scf::YieldOp>(loc, dstBuf.val);

             builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-            dstBuf.insertOrStore(builder, loc, v, lcvs);
-            builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+            dstBuf.insert(builder, loc, v, lcvs);
+            builder.create<scf::YieldOp>(loc, dstBuf.val);

             // Exits the ifOp, update the sparse tensor SSA value.
             builder.setInsertionPointAfter(ifOp);
-            dstBuf.updateSSA(ifOp.getResult(0));
+            dstBuf.val = ifOp.getResult(0);
           } else {
-            dstBuf.insertOrStore(builder, loc, v, lcvs);
+            dstBuf.insert(builder, loc, v, lcvs);
           }
-          if (reduc.empty())
-            builder.create<sparse_tensor::YieldOp>(loc);
-          else
-            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
+          builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
         });

     rewriter.setInsertionPointAfter(foreachOp);

     // Exits the for loop, links the SSA chain.
-    if (!foreachOp.getResults().empty())
-      dstBuf.updateSSA(foreachOp.getResult(0));
+    dstBuf.val = foreachOp.getResult(0);

     Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
     rewriter.replaceOp(op, ret);
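For readers skimming the diff: with the rewritten TensorLike, the dense destination is now materialized and filled at the tensor level instead of as a memref, and the destination is threaded through the loop as an SSA value. A rough sketch of the IR shape the rewritten patterns build for a dense destination follows; the SSA names, the 13xi32 shape, and the #SparseVector encoding are illustrative (borrowed from the test file below), not copied from actual generated output.

// Hypothetical sketch only; names, shape, and encoding are illustrative.
%buf  = bufferization.alloc_tensor() : tensor<13xi32>
%zero = arith.constant 0 : i32
%init = linalg.fill ins(%zero : i32) outs(%buf : tensor<13xi32>) -> tensor<13xi32>
// The destination tensor is carried as the foreach reduction value and
// updated with tensor.insert instead of memref.store.
%out = sparse_tensor.foreach in %src init(%init)
    : tensor<13xi32, #SparseVector>, tensor<13xi32> -> tensor<13xi32> do {
^bb0(%i: index, %v: i32, %t: tensor<13xi32>):
  %t2 = tensor.insert %v into %t[%i] : tensor<13xi32>
  sparse_tensor.yield %t2 : tensor<13xi32>
}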
35 changes: 14 additions & 21 deletions mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
@@ -14,83 +14,76 @@

 // CHECK-LABEL: func.func @sparse_convert_1d
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
+// CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13xi32> {
   %0 = sparse_tensor.convert %arg0 : tensor<13xi32, #SparseVector> to tensor<13xi32>
   return %0 : tensor<13xi32>
 }

 // CHECK-LABEL: func.func @sparse_convert_1d_dyn
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
+// CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<?xi32> {
   %0 = sparse_tensor.convert %arg0 : tensor<?xi32, #SparseVector> to tensor<?xi32>
   return %0 : tensor<?xi32>
 }

 // CHECK-LABEL: func.func @sparse_convert_2d
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
+// CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64, #SparseMatrix> to tensor<2x4xf64>
   return %0 : tensor<2x4xf64>
 }

 // CHECK-LABEL: func.func @sparse_convert_2d_dyn
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
+// CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tensor<?x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x4xf64, #SparseMatrix> to tensor<?x4xf64>
   return %0 : tensor<?x4xf64>
 }

 // CHECK-LABEL: func.func @sparse_convert_2d_dyn1
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
+// CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tensor<2x?xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x?xf64, #SparseMatrix> to tensor<2x?xf64>
   return %0 : tensor<2x?xf64>
 }

 // CHECK-LABEL: func.func @sparse_convert_2d_dyn2
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
+// CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tensor<?x?xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x?xf64, #SparseMatrix> to tensor<?x?xf64>
   return %0 : tensor<?x?xf64>
 }

 // CHECK-LABEL: func.func @sparse_convert_3d
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
+// CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_3d(%arg0: tensor<2x3x4xf64, #SparseTensor>) -> tensor<2x3x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x3x4xf64, #SparseTensor> to tensor<2x3x4xf64>
   return %0 : tensor<2x3x4xf64>
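As a companion to the C++ changes above (not part of the patch itself): when a zero check is required, the rewritten patterns guard each insertion with an scf.if and yield the updated destination, so the tensor value stays a proper SSA chain through the loop. A minimal sketch of the per-element body, with hypothetical names, shape, and an f64 element type:

// Hypothetical sketch of the body built inside sparse_tensor.foreach;
// %t is the incoming destination tensor, %i/%j the level coordinates.
%zero    = arith.constant 0.0 : f64
%nonzero = arith.cmpf une, %v, %zero : f64
%next = scf.if %nonzero -> (tensor<9x4xf64, #SparseMatrix>) {
  // Only nonzero values are inserted into the sparse destination.
  %ins = sparse_tensor.insert %v into %t[%i, %j] : tensor<9x4xf64, #SparseMatrix>
  scf.yield %ins : tensor<9x4xf64, #SparseMatrix>
} else {
  // Otherwise the destination is passed through unchanged.
  scf.yield %t : tensor<9x4xf64, #SparseMatrix>
}
sparse_tensor.yield %next : tensor<9x4xf64, #SparseMatrix>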