[mlir][sparse] avoid tensor to memref conversion in sparse tensor rewriting rules. (#69362)
PeimingLiu authored Oct 17, 2023
1 parent fd31112 commit 71c97c7
Showing 3 changed files with 132 additions and 158 deletions.
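
For orientation before the diff: the rewriters used to materialize dense destinations as memref buffers and convert back to a tensor at the end; after this change they stay in tensor SSA form throughout. Below is a hand-written sketch of the lowering for a 1-D convert (illustrative only; the value names %src, %zero, %buf, %iter are assumed, and op syntax is approximate rather than verbatim pass output):

  // Before (sketch): dense destination is a memref, mutated by side effect
  // and converted back to a tensor at the end.
  %buf = memref.alloc() : memref<13xi32>
  linalg.fill ins(%zero : i32) outs(%buf : memref<13xi32>)
  sparse_tensor.foreach in %src : tensor<13xi32, #SparseVector> do {
  ^bb0(%i: index, %v: i32):
    memref.store %v, %buf[%i] : memref<13xi32>
  }
  %res = bufferization.to_tensor %buf : memref<13xi32>

  // After (sketch): the destination stays a tensor and is threaded through
  // the loop as an SSA value.
  %empty = bufferization.alloc_tensor() : tensor<13xi32>
  %init = linalg.fill ins(%zero : i32) outs(%empty : tensor<13xi32>) -> tensor<13xi32>
  %res = sparse_tensor.foreach in %src init(%init)
      : tensor<13xi32, #SparseVector>, tensor<13xi32> -> tensor<13xi32> do {
  ^bb0(%i: index, %v: i32, %iter: tensor<13xi32>):
    %new = tensor.insert %v into %iter[%i] : tensor<13xi32>
    sparse_tensor.yield %new : tensor<13xi32>
  }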
107 changes: 41 additions & 66 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -829,47 +829,40 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
   }
 };
 
 // A trivial wrapper to help generate different operations for dense/sparse
 // tensors.
 struct TensorLike {
   TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
-             ValueRange sizes)
-      : isSparse(rtt.getEncoding() != nullptr) {
+             ValueRange sizes) {
     SmallVector<Value> dynSzs;
     getDynamicSizes(rtt, sizes, dynSzs);
 
-    if (isSparse)
-      val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
-    else
-      val = allocDenseTensor(builder, loc, rtt, sizes);
-  };
-
-  void insertOrStore(OpBuilder &builder, Location loc, Value v,
-                     ValueRange crds) {
-    if (isSparse)
-      val = builder.create<InsertOp>(loc, v, val, crds);
-    else
-      builder.create<memref::StoreOp>(loc, v, val, crds);
+    val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
+    if (!isSparse()) {
+      Value c0 = constantZero(builder, loc, rtt.getElementType());
+      val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
+    }
   }
 
-  Value getSSA() const {
-    // We don't need to maintain the SSA chain for a memref value.
-    return isSparse ? val : nullptr;
+  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
+    // TODO: Unify these two.
+    if (isSparse())
+      val = builder.create<sparse_tensor::InsertOp>(loc, v, val, crds);
+    else
+      val = builder.create<tensor::InsertOp>(loc, v, val, crds);
   }
 
   Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
-    if (isSparse)
+    if (isSparse())
       return builder.create<LoadOp>(loc, val, true);
-    return builder.create<bufferization::ToTensorOp>(loc, rtp, val);
+    return val;
   }
 
-  void updateSSA(Value v) {
-    // Dense memref is a non-SSA value.
-    assert(isSparse);
-    val = v;
+  bool isSparse() const {
+    return getSparseTensorEncoding(val.getType()) != nullptr;
   }
 
-private:
-  bool isSparse;
-  Value val; // either a memref (for dense tensor) or a sparse tensor.
+  Value val;
 };
 
 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
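
At the IR level, the unified insert above generates one of two tensor-level ops depending on the destination encoding; roughly (a sketch with assumed names %v, %dst and an assumed #CSR encoding alias):

  // Sparse destination: insertion yields a new tensor value.
  %d1 = sparse_tensor.insert %v into %dst[%i, %j] : tensor<4x4xf64, #CSR>
  // Dense destination: tensor.insert replaces the former memref.store,
  // likewise yielding a new tensor value instead of mutating a buffer.
  %d2 = tensor.insert %v into %dst[%i, %j] : tensor<4x4xf64>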
@@ -901,14 +894,14 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
 
     TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
     Value offset = constantIndex(rewriter, loc, 0);
-    Value iterArg = dstBuf.getSSA();
+    Value iterArg = dstBuf.val;
 
     ForeachOp foreachOp;
     for (Value input : op.getInputs()) {
       // Builds a for op for each input tensor to append new values into the
       // output tensor.
       foreachOp = rewriter.create<ForeachOp>(
-          loc, input, iterArg ? ValueRange{iterArg} : ValueRange{},
+          loc, input, iterArg,
           [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
               ValueRange reduc) {
             SmallVector<Value> dstLcvs(dstTp.getLvlRank());
@@ -920,32 +913,26 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
               // FIXME: `toStoredDim` is deprecated
               dstLcvs[toStoredDim(dstTp.getEncoding(), d)] = crd;
             }
-
-            if (!reduc.empty())
-              dstBuf.updateSSA(reduc.front());
-
+            // Enters foreach, updates the SSA chain.
+            dstBuf.val = reduc.front();
             if (!dstTp.isAllDense()) {
               Value cond = genIsNonzero(builder, loc, v);
               auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                     /*else*/ true);
               builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+              builder.create<scf::YieldOp>(loc, dstBuf.val);
 
               builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
-              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+              dstBuf.insert(builder, loc, v, dstLcvs);
+              builder.create<scf::YieldOp>(loc, dstBuf.val);
 
               // Exits the ifOp, update the sparse tensor SSA value.
               builder.setInsertionPointAfter(ifOp);
-              assert(!reduc.empty());
-              dstBuf.updateSSA(ifOp.getResult(0));
+              dstBuf.val = ifOp.getResult(0);
             } else {
-              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
+              dstBuf.insert(builder, loc, v, dstLcvs);
             }
-            if (reduc.empty())
-              builder.create<sparse_tensor::YieldOp>(loc);
-            else
-              builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
+            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
           });
       // Accumulates the offset. Note that only static-shaped inputs are allowed
       // by concatenate op verifier, which saves us from computing the offset
@@ -955,15 +942,11 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
       offset = rewriter.create<arith::AddIOp>(
           loc, offset, constantIndex(rewriter, loc, *sh));
 
-      if (!foreachOp.getResults().empty()) {
-        iterArg = foreachOp.getResult(0);
-        dstBuf.updateSSA(iterArg);
-      }
+      iterArg = foreachOp.getResult(0);
+      dstBuf.val = iterArg;
     }
 
-    if (!foreachOp.getResults().empty())
-      dstBuf.updateSSA(iterArg);
-
+    dstBuf.val = iterArg;
     Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
     rewriter.replaceOp(op, ret);
     return success();
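
The scf.if generated around each conditional insertion (here and in DirectConvertRewriter below) now yields the destination tensor through both branches, so a single SSA chain survives the conditional; schematically (a sketch, assuming an f64 element type and a hypothetical #COO encoding alias):

  %next = scf.if %nonzero -> (tensor<?x?xf64, #COO>) {
    // Nonzero value: insert and yield the updated tensor.
    %ins = sparse_tensor.insert %v into %iter[%i, %j] : tensor<?x?xf64, #COO>
    scf.yield %ins : tensor<?x?xf64, #COO>
  } else {
    // Zero value: yield the incoming tensor unchanged.
    scf.yield %iter : tensor<?x?xf64, #COO>
  }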
@@ -1010,15 +993,12 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
     ValueRange vs;
     TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
 
-    Value iterArg = dstBuf.getSSA();
     auto foreachOp = rewriter.create<ForeachOp>(
-        loc, src, iterArg ? ValueRange{iterArg} : ValueRange{}, foreachOrder,
+        loc, src, dstBuf.val, foreachOrder,
         [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
             ValueRange reduc) {
           // Enters the loop, update the SSA value for insertion chain.
-          if (!reduc.empty())
-            dstBuf.updateSSA(reduc.front());
-
+          dstBuf.val = reduc.front();
           const Dimension dimRank = dstStt.getDimRank();
           const Level lvlRank = dstStt.getLvlRank();
           SmallVector<Value> lcvs(lvlRank);
@@ -1028,34 +1008,29 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
           }
 
           if (!skipZeroCheck) {
-            assert(!reduc.empty());
             Value cond = genIsNonzero(builder, loc, v);
             auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                   /*else*/ true);
             builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-            builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+            builder.create<scf::YieldOp>(loc, dstBuf.val);
 
             builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-            dstBuf.insertOrStore(builder, loc, v, lcvs);
-            builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+            dstBuf.insert(builder, loc, v, lcvs);
+            builder.create<scf::YieldOp>(loc, dstBuf.val);
 
             // Exits the ifOp, update the sparse tensor SSA value.
             builder.setInsertionPointAfter(ifOp);
-            dstBuf.updateSSA(ifOp.getResult(0));
+            dstBuf.val = ifOp.getResult(0);
           } else {
-            dstBuf.insertOrStore(builder, loc, v, lcvs);
+            dstBuf.insert(builder, loc, v, lcvs);
           }
-          if (reduc.empty())
-            builder.create<sparse_tensor::YieldOp>(loc);
-          else
-            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
+          builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
         });
 
     rewriter.setInsertionPointAfter(foreachOp);
 
     // Exits the for loop, links the SSA chain.
-    if (!foreachOp.getResults().empty())
-      dstBuf.updateSSA(foreachOp.getResult(0));
+    dstBuf.val = foreachOp.getResult(0);
 
     Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
     rewriter.replaceOp(op, ret);
35 changes: 14 additions & 21 deletions mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
@@ -14,83 +14,76 @@
 
 // CHECK-LABEL: func.func @sparse_convert_1d
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13xi32> {
   %0 = sparse_tensor.convert %arg0 : tensor<13xi32, #SparseVector> to tensor<13xi32>
   return %0 : tensor<13xi32>
 }
 
 // CHECK-LABEL: func.func @sparse_convert_1d_dyn
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<?xi32> {
   %0 = sparse_tensor.convert %arg0 : tensor<?xi32, #SparseVector> to tensor<?xi32>
   return %0 : tensor<?xi32>
 }
 
 // CHECK-LABEL: func.func @sparse_convert_2d
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64, #SparseMatrix> to tensor<2x4xf64>
   return %0 : tensor<2x4xf64>
 }
 
 // CHECK-LABEL: func.func @sparse_convert_2d_dyn
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tensor<?x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x4xf64, #SparseMatrix> to tensor<?x4xf64>
   return %0 : tensor<?x4xf64>
 }
 
 // CHECK-LABEL: func.func @sparse_convert_2d_dyn1
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tensor<2x?xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x?xf64, #SparseMatrix> to tensor<2x?xf64>
   return %0 : tensor<2x?xf64>
 }
 
 // CHECK-LABEL: func.func @sparse_convert_2d_dyn2
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tensor<?x?xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x?xf64, #SparseMatrix> to tensor<?x?xf64>
   return %0 : tensor<?x?xf64>
 }
 
 // CHECK-LABEL: func.func @sparse_convert_3d
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_3d(%arg0: tensor<2x3x4xf64, #SparseTensor>) -> tensor<2x3x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x3x4xf64, #SparseTensor> to tensor<2x3x4xf64>
   return %0 : tensor<2x3x4xf64>