
[mlir][sparse] avoid tensor to memref conversion in sparse tensor rewriting rules #69362

Merged
1 commit merged on Oct 17, 2023

Conversation

PeimingLiu
Member

[mlir][sparse] avoid tensor to memref conversion in sparse tensor rewriting rules.

PeimingLiu requested review from aartbik and yinying-lisa-li and removed the request for aartbik on Oct 17, 2023 at 18:00
llvmbot added the mlir:sparse (Sparse compiler in MLIR) and mlir labels on Oct 17, 2023

llvmbot commented Oct 17, 2023

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

[mlir][sparse] avoid tensor to memref conversion in sparse tensor rewriting rules.


Patch is 27.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69362.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+41-66)
  • (modified) mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir (+14-21)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_concat.mlir (+77-71)
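
In short, the rewriting rules now build the dense destination as a true tensor value (bufferization.alloc_tensor followed by tensor.insert) instead of a raw memref (memref.alloc, memref.store, bufferization.to_tensor), so a single SSA value can be threaded through the foreach loop in both the sparse and dense cases. The following hand-written sketch condenses the lowering change from the CHECK patterns below; SSA names are illustrative and the foreach syntax is abbreviated:

  // Before this patch: the dense destination is a memref, updated by side effect.
  %buf = memref.alloc() : memref<13xi32>
  linalg.fill ins(%c0_i32 : i32) outs(%buf : memref<13xi32>)
  sparse_tensor.foreach in %arg0 : tensor<13xi32, #SparseVector> do {
  ^bb0(%crd: index, %v: i32):
    memref.store %v, %buf[%crd] : memref<13xi32>
  }
  %res = bufferization.to_tensor %buf : memref<13xi32>

  // After this patch: the destination stays a tensor, threaded as an SSA value.
  %init = bufferization.alloc_tensor() : tensor<13xi32>
  %fill = linalg.fill ins(%c0_i32 : i32) outs(%init : tensor<13xi32>) -> tensor<13xi32>
  %res = sparse_tensor.foreach in %arg0 init(%fill)
      : tensor<13xi32, #SparseVector>, tensor<13xi32> -> tensor<13xi32> do {
  ^bb0(%crd: index, %v: i32, %iter: tensor<13xi32>):
    %new = tensor.insert %v into %iter[%crd] : tensor<13xi32>
    sparse_tensor.yield %new : tensor<13xi32>
  }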
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 1bfee3aa1d7ee8e..e50b14975e83d63 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -829,47 +829,40 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
   }
 };
 
+// A trivial wrapper to help generate different operations for dense/sparse
+// tensors.
 struct TensorLike {
   TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
-             ValueRange sizes)
-      : isSparse(rtt.getEncoding() != nullptr) {
+             ValueRange sizes) {
     SmallVector<Value> dynSzs;
     getDynamicSizes(rtt, sizes, dynSzs);
 
-    if (isSparse)
-      val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
-    else
-      val = allocDenseTensor(builder, loc, rtt, sizes);
-  };
-
-  void insertOrStore(OpBuilder &builder, Location loc, Value v,
-                     ValueRange crds) {
-    if (isSparse)
-      val = builder.create<InsertOp>(loc, v, val, crds);
-    else
-      builder.create<memref::StoreOp>(loc, v, val, crds);
+    val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
+    if (!isSparse()) {
+      Value c0 = constantZero(builder, loc, rtt.getElementType());
+      val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
+    }
   }
 
-  Value getSSA() const {
-    // We don't need to maintain the SSA chain for a memref value.
-    return isSparse ? val : nullptr;
+  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
+    // TODO: Unify these two.
+    if (isSparse())
+      val = builder.create<sparse_tensor::InsertOp>(loc, v, val, crds);
+    else
+      val = builder.create<tensor::InsertOp>(loc, v, val, crds);
   }
 
   Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
-    if (isSparse)
+    if (isSparse())
       return builder.create<LoadOp>(loc, val, true);
-    return builder.create<bufferization::ToTensorOp>(loc, rtp, val);
+    return val;
   }
 
-  void updateSSA(Value v) {
-    // Dense memref is a non-SSA value.
-    assert(isSparse);
-    val = v;
+  bool isSparse() const {
+    return getSparseTensorEncoding(val.getType()) != nullptr;
   }
 
-private:
-  bool isSparse;
-  Value val; // either a memref (for dense tensor) or a sparse tensor.
+  Value val;
 };
 
 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
@@ -901,14 +894,14 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
 
     TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
     Value offset = constantIndex(rewriter, loc, 0);
-    Value iterArg = dstBuf.getSSA();
+    Value iterArg = dstBuf.val;
 
     ForeachOp foreachOp;
     for (Value input : op.getInputs()) {
       // Builds a for op for each input tensor to append new values into the
       // output tensor.
       foreachOp = rewriter.create<ForeachOp>(
-          loc, input, iterArg ? ValueRange{iterArg} : ValueRange{},
+          loc, input, iterArg,
           [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
               ValueRange reduc) {
             SmallVector<Value> dstLcvs(dstTp.getLvlRank());
@@ -920,32 +913,26 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
               // FIXME: `toStoredDim` is deprecated
               dstLcvs[toStoredDim(dstTp.getEncoding(), d)] = crd;
             }
-
-            if (!reduc.empty())
-              dstBuf.updateSSA(reduc.front());
-
+            // Enters foreach, updates the SSA chain.
+            dstBuf.val = reduc.front();
             if (!dstTp.isAllDense()) {
               Value cond = genIsNonzero(builder, loc, v);
               auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                     /*else*/ true);
               builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+              builder.create<scf::YieldOp>(loc, dstBuf.val);
 
               builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
-              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+              dstBuf.insert(builder, loc, v, dstLcvs);
+              builder.create<scf::YieldOp>(loc, dstBuf.val);
 
               // Exits the ifOp, update the sparse tensor SSA value.
               builder.setInsertionPointAfter(ifOp);
-              assert(!reduc.empty());
-              dstBuf.updateSSA(ifOp.getResult(0));
+              dstBuf.val = ifOp.getResult(0);
             } else {
-              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
+              dstBuf.insert(builder, loc, v, dstLcvs);
             }
-            if (reduc.empty())
-              builder.create<sparse_tensor::YieldOp>(loc);
-            else
-              builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
+            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
           });
       // Accumulates the offset. Note that only static-shaped inputs are allowed
       // by concatenate op verifier, which saves us from computing the offset
@@ -955,15 +942,11 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
       offset = rewriter.create<arith::AddIOp>(
           loc, offset, constantIndex(rewriter, loc, *sh));
 
-      if (!foreachOp.getResults().empty()) {
-        iterArg = foreachOp.getResult(0);
-        dstBuf.updateSSA(iterArg);
-      }
+      iterArg = foreachOp.getResult(0);
+      dstBuf.val = iterArg;
     }
 
-    if (!foreachOp.getResults().empty())
-      dstBuf.updateSSA(iterArg);
-
+    dstBuf.val = iterArg;
     Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
     rewriter.replaceOp(op, ret);
     return success();
@@ -1010,15 +993,12 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
     ValueRange vs;
     TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
 
-    Value iterArg = dstBuf.getSSA();
     auto foreachOp = rewriter.create<ForeachOp>(
-        loc, src, iterArg ? ValueRange{iterArg} : ValueRange{}, foreachOrder,
+        loc, src, dstBuf.val, foreachOrder,
         [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
             ValueRange reduc) {
           // Enters the loop, update the SSA value for insertion chain.
-          if (!reduc.empty())
-            dstBuf.updateSSA(reduc.front());
-
+          dstBuf.val = reduc.front();
           const Dimension dimRank = dstStt.getDimRank();
           const Level lvlRank = dstStt.getLvlRank();
           SmallVector<Value> lcvs(lvlRank);
@@ -1028,34 +1008,29 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
           }
 
           if (!skipZeroCheck) {
-            assert(!reduc.empty());
             Value cond = genIsNonzero(builder, loc, v);
             auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                   /*else*/ true);
             builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-            builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+            builder.create<scf::YieldOp>(loc, dstBuf.val);
 
             builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-            dstBuf.insertOrStore(builder, loc, v, lcvs);
-            builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+            dstBuf.insert(builder, loc, v, lcvs);
+            builder.create<scf::YieldOp>(loc, dstBuf.val);
 
             // Exits the ifOp, update the sparse tensor SSA value.
             builder.setInsertionPointAfter(ifOp);
-            dstBuf.updateSSA(ifOp.getResult(0));
+            dstBuf.val = ifOp.getResult(0);
           } else {
-            dstBuf.insertOrStore(builder, loc, v, lcvs);
+            dstBuf.insert(builder, loc, v, lcvs);
           }
-          if (reduc.empty())
-            builder.create<sparse_tensor::YieldOp>(loc);
-          else
-            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
+          builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
         });
 
     rewriter.setInsertionPointAfter(foreachOp);
 
     // Exits the for loop, links the SSA chain.
-    if (!foreachOp.getResults().empty())
-      dstBuf.updateSSA(foreachOp.getResult(0));
+    dstBuf.val = foreachOp.getResult(0);
 
     Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
     rewriter.replaceOp(op, ret);
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
index c22f051a0d5854d..e2dcb068e11851e 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
@@ -14,11 +14,10 @@
 
 // CHECK-LABEL:  func.func @sparse_convert_1d
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13xi32> {
   %0 = sparse_tensor.convert %arg0 : tensor<13xi32, #SparseVector> to tensor<13xi32>
   return %0 : tensor<13xi32>
@@ -26,11 +25,10 @@ func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13x
 
 // CHECK-LABEL:  func.func @sparse_convert_1d_dyn
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<?xi32> {
   %0 = sparse_tensor.convert %arg0 : tensor<?xi32, #SparseVector> to tensor<?xi32>
   return %0 : tensor<?xi32>
@@ -38,11 +36,10 @@ func.func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<
 
 // CHECK-LABEL:  func.func @sparse_convert_2d
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64, #SparseMatrix> to tensor<2x4xf64>
   return %0 : tensor<2x4xf64>
@@ -50,11 +47,10 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x
 
 // CHECK-LABEL:  func.func @sparse_convert_2d_dyn
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tensor<?x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x4xf64, #SparseMatrix> to tensor<?x4xf64>
   return %0 : tensor<?x4xf64>
@@ -62,11 +58,10 @@ func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tens
 
 // CHECK-LABEL:  func.func @sparse_convert_2d_dyn1
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tensor<2x?xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x?xf64, #SparseMatrix> to tensor<2x?xf64>
   return %0 : tensor<2x?xf64>
@@ -74,11 +69,10 @@ func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tens
 
 // CHECK-LABEL:  func.func @sparse_convert_2d_dyn2
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tensor<?x?xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x?xf64, #SparseMatrix> to tensor<?x?xf64>
   return %0 : tensor<?x?xf64>
@@ -86,11 +80,10 @@ func.func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tens
 
 // CHECK-LABEL:  func.func @sparse_convert_3d
 // CHECK-NOT:      sparse_tensor.reorder_coo
-// CHECK:          memref.alloc
+// CHECK:          bufferization.alloc_tensor
 // CHECK:          linalg.fill
 // CHECK:          sparse_tensor.foreach
-// CHECK:            memref.store
-// CHECK:          bufferization.to_tensor
+// CHECK:            tensor.insert
 func.func @sparse_convert_3d(%arg0: tensor<2x3x4xf64, #SparseTensor>) -> tensor<2x3x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x3x4xf64, #SparseTensor> to tensor<2x3x4xf64>
   return %0 : tensor<2x3x4xf64>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
index bdfab54dc6daeb5..f3d3dd28563e891 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
@@ -176,77 +176,83 @@ func.func @concat_sparse_sparse_dynamic(%arg0: tensor<2x4xf64, #DCSR>,
     return %0 : tensor<?x?xf64, #DCSR>
 }
 
-// CHECK-LABEL: @concat_sparse_sparse_dense(
-//  CHECK-SAME:  %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor
-//  CHECK-SAME:  %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor
-//  CHECK-SAME:  %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor
-//   CHECK-DAG:  %[[TMP_c0:.*]] = arith.constant 0 : index
-//   CHECK-DAG:  %[[TMP_c1:.*]] = arith.constant 1 : index
-//   CHECK-DAG:  %[[TMP_c5:.*]] = arith.constant 5 : index
-//   CHECK-DAG:  %[[TMP_c2:.*]] = arith.constant 2 : index
-//   CHECK-DAG:  %[[TMP_c9:.*]] = arith.constant 9 : index
-//   CHECK-DAG:  %[[TMP_c4:.*]] = arith.constant 4 : index
-//   CHECK-DAG:  %[[TMP_d0:.*]] = arith.constant 0.000000e+00 : f64
-//       CHECK:  %[[A:.*]] = memref.alloc(%[[TMP_c9]], %[[TMP_c4]]) : memref<?x?xf64>
-//       CHECK:  linalg.fill ins(%[[TMP_d0]] : f64) outs(%[[A]] : memref<?x?xf64>)
-//       CHECK:  %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index} : tensor<2x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_4:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 1 : index} : tensor<2x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]] : tensor<2x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
-//       CHECK:  %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
-//       CHECK:  scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]]
-//       CHECK:    %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref<?xindex>
-//   CHECK-DAG:    %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref<?xindex>
-//   CHECK-DAG:    %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-//       CHECK:    %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref<?xindex>
-//       CHECK:    scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-//       CHECK:      %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref<?xindex>
-//       CHECK:      %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref<?xf64>
-//       CHECK:      memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_23]], %[[TMP_27]]] : memref<?x?xf64>
-//       CHECK:    }
-//       CHECK:  }
-//       CHECK:  %[[TMP_8:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_9:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_10:.*]] = sparse_tensor.positions %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_11:.*]] = sparse_tensor.coordinates %[[TMP_arg1]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_12:.*]] = sparse_tensor.values %[[TMP_arg1]] : tensor<3x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_13:.*]] = memref.load %[[TMP_8]][%[[TMP_c0]]] : memref<?xindex>
-//       CHECK:  %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref<?xindex>
-//       CHECK:  scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]]
-//       CHECK:    %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref<?xindex>
-//   CHECK-DAG:    %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref<?xindex>
-//   CHECK-DAG:    %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-//       CHECK:    %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref<?xindex>
-//       CHECK:    scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-//       CHECK:      %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref<?xindex>
-//       CHECK:      %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref<?xf64>
-//       CHECK:      %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
-//       CHECK:      memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
-//       CHECK:    }
-//       CHECK:  }
-//       CHECK:  %[[TMP_15:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_16:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 0 : index} : tensor<4x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_17:.*]] = sparse_tensor.positions %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_18:.*]] = sparse_tensor.coordinates %[[TMP_arg2]] {level = 1 : index} : tensor<4x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_19:.*]] = sparse_tensor.values %[[TMP_arg2]] : tensor<4x4xf64, #sparse_tensor
-//       CHECK:  %[[TMP_20:.*]] = memref.load %[[TMP_15]][%[[TMP_c0]]] : memref<?xindex>
-//       CHECK:  %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref<?xindex>
-//       CHECK:  scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]]
-//       CHECK:    %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xindex>
-//       CHECK:    %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref<?xindex>
-//       CHECK:    %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-//       CHECK:    %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref<?xindex>
-//       CHECK:    scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]]
-//       CHECK:      %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
-//       CHECK:      %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xf64>
-//       CHECK:      %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index
-//       CHECK:      memref.store %[[TMP_28]], %[[A]]{{\[}}%[[TMP_29]], %[[TMP_27]]] : memref<?x?xf64>
-//       CHECK:    }
-//       CHECK:  }
-//       CHECK:  %[[R:.*]] = bufferization.to_tensor %[[A]] : memref<?x?xf64>
-//       CHECK:  return %[[R]] : tensor<?x?xf64>
+// CHECK-LABEL:   func.func @concat_sparse_sparse_dense(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<2x4xf64, #sparse_tensor
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<3x4xf64, #sparse_tensor
+// CHECK-SAME:      %[[VAL_2...
[truncated]
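
For values that are not statically known to be nonzero, both rewriters still wrap the insertion in an scf.if so the loop-carried tensor is only updated on the nonzero path. The body emitted inside the foreach has roughly this shape (a hand-written sketch following DirectConvertRewriter above, not actual pass output; %iter stands for the tensor carried by the foreach reducer):

  %cond = arith.cmpf une, %v, %zero : f64          // genIsNonzero
  %next = scf.if %cond -> (tensor<?x?xf64>) {
    %ins = tensor.insert %v into %iter[%i, %j] : tensor<?x?xf64>
    scf.yield %ins : tensor<?x?xf64>
  } else {
    scf.yield %iter : tensor<?x?xf64>              // unchanged on the zero path
  }
  sparse_tensor.yield %next : tensor<?x?xf64>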
