
[mlir] Fix bug in UnPackOp tiling implementation causing infinite loop #113571

Merged: 1 commit merged into llvm:main on Oct 25, 2024

Conversation

@Max191 (Contributor) commented on Oct 24, 2024

This fixes a bug in the tiling implementation of tensor.unpack that caused an infinite loop when certain unpack ops were tiled and fused as producers. The tiled implementation of tensor.unpack sometimes needs to create an additional tensor.extract_slice on the result of the tiled unpack op, but this slice was being added to the `generatedSlices` of the tiling result. Since the `generatedSlices` are used to find the next producers to fuse, the same unpack op kept being fused again even after it was already inside the loop, so fusion never terminated. The fix adds the slice of the source to `generatedSlices` instead of the slice of the result.

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
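For intuition about why the result slice caused non-termination, here is a minimal, self-contained C++ sketch. It is not the actual MLIR fusion driver, and every type and name in it is invented for illustration; it simply models `generatedSlices` as a worklist of fusion candidates. When the tiled op also reports a slice of its own result, the worklist keeps pointing back at the op that was just tiled and never drains; when only the source slice is reported, the worklist empties.

```cpp
// Simplified model only -- not MLIR code. All types and names are invented.
#include <deque>
#include <iostream>
#include <string>
#include <vector>

struct Slice {
  std::string producer; // op whose value this slice reads (a fusion candidate)
};

struct TilingResult {
  std::vector<Slice> generatedSlices;
};

// "Tiling" an op yields a slice of its source operand; the buggy version also
// reports a slice of the tiled op's own result.
TilingResult tileOp(const std::string &op, bool includeResultSlice) {
  TilingResult res;
  res.generatedSlices.push_back({"source_of_" + op});
  if (includeResultSlice)
    res.generatedSlices.push_back({op});
  return res;
}

void fuseGreedily(bool includeResultSlice) {
  auto initial = tileOp("unpack", includeResultSlice).generatedSlices;
  std::deque<Slice> worklist(initial.begin(), initial.end());
  int steps = 0;
  while (!worklist.empty() && steps++ < 10) { // cap iterations for the demo
    Slice s = worklist.front();
    worklist.pop_front();
    if (s.producer == "unpack") {
      // The slice points back at the op we just tiled, so it gets "fused"
      // again and produces the same slices again: the loop never ends.
      auto more = tileOp("unpack", includeResultSlice).generatedSlices;
      worklist.insert(worklist.end(), more.begin(), more.end());
    }
    // Slices of real producers (e.g. "source_of_unpack") are terminal here.
  }
  std::cout << (worklist.empty() ? "terminated\n" : "still looping after cap\n");
}

int main() {
  fuseGreedily(/*includeResultSlice=*/true);  // buggy behavior: never drains
  fuseGreedily(/*includeResultSlice=*/false); // fixed behavior: terminates
  return 0;
}
```

In the actual pass, the candidates are tensor.extract_slice ops whose sources are walked to find the next producers to fuse (as the description above explains); the termination argument is the same.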
@llvmbot (Member) commented on Oct 24, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: None (Max191)

Changes

This fixes a bug in the tiling implementation of tensor.unpack that caused an infinite loop when certain unpack ops were tiled and fused as producers. The tiled implementation of tensor.unpack sometimes needs to create an additional tensor.extract_slice on the result of the tiled unpack op, but this slice was being added to the `generatedSlices` of the tiling result. Since the `generatedSlices` are used to find the next producers to fuse, the same unpack op kept being fused again even after it was already inside the loop, so fusion never terminated. The fix adds the slice of the source to `generatedSlices` instead of the slice of the result.


Full diff: https://github.com/llvm/llvm-project/pull/113571.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (+4-4)
  • (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir (+47)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 104d6ae1f9f6b5..ba41904b370991 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -554,13 +554,14 @@ struct UnPackOpTiling
     sliceSrcIndices.append(numInnerTiles, zeroAttr);
     sliceSrcSizes.append(unpackOp.getMixedTiles());
     sliceSrcStrides.append(numInnerTiles, oneAttr);
-    Value sliceSource =
+    SmallVector<Operation *> generatedSlices;
+    ExtractSliceOp sliceSource =
         b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                  sliceSrcSizes, sliceSrcStrides);
+    generatedSlices.push_back(sliceSource);
 
     SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
     Value sliceDest;
-    SmallVector<Operation *> generatedSlices;
     if (isPerfectTilingCase) {
       auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
                                                   offsets, sizes, destStrides);
@@ -571,7 +572,7 @@ struct UnPackOpTiling
                                     unpackOp.getDestType().getElementType());
     }
 
-    SmallVector<Value> tiledOperands = {sliceSource, sliceDest};
+    SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
     for (auto tile : unpackOp.getInnerTiles())
       tiledOperands.push_back(tile);
 
@@ -586,7 +587,6 @@ struct UnPackOpTiling
     auto extractSlice =
         b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                  resultOffsetsFromDest, sizes, destStrides);
-    generatedSlices.push_back(extractSlice);
     return TilingResult{
         {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
   }
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index 3ea1929e4ed785..5f7663af773a4a 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -587,3 +587,50 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:     %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[IF_RESULT]]
 //       CHECK:     scf.yield %[[INSERT_SLICE]]
 //       CHECK:   return %[[FOR_RESULT]]
+
+// -----
+
+func.func @imperfect_unpack_producer_fusion(%source: tensor<1x1x288x8x4xf32>, %dest: tensor<1x2x1152xf32>) -> tensor<1x2x1152xf32> {
+  %0 = tensor.unpack %source
+      outer_dims_perm = [0, 1, 2]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [8, 4] into %dest
+      : tensor<1x1x288x8x4xf32> -> tensor<1x2x1152xf32>
+  %1 = tensor.empty() : tensor<1x2x1152xf32>
+  %cst = arith.constant 1.0 : f32
+  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+                       iterator_types = ["parallel", "parallel", "parallel"]}
+                       ins(%0 : tensor<1x2x1152xf32>)
+                       outs(%1 : tensor<1x2x1152xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %7 = arith.addf %in, %cst : f32
+    linalg.yield %7 : f32
+  } -> tensor<1x2x1152xf32>
+  return %2 : tensor<1x2x1152xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.structured.fuse %matmul [0, 1, 0]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @imperfect_unpack_producer_fusion
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<1x1x288x8x4xf32>
+//  CHECK-SAME:     %[[ARG1:.+]]: tensor<1x2x1152xf32>
+//       CHECK:   %[[FOR_RESULT:.+]] = scf.for{{.*}}iter_args(%[[ITER_ARG:.+]] = {{.*}})
+//       CHECK:     %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+//       CHECK:     %[[UNPACK:.+]] = tensor.unpack %[[SLICE]]
+//   CHECK-DAG:     %[[UNPACK_SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+//   CHECK-DAG:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]]
+//       CHECK:     %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[UNPACK_SLICE]]
+//  CHECK-SAME:         outs(%[[INIT_SLICE]]
+//       CHECK:     %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER_ARG]]
+//       CHECK:     scf.yield %[[INSERT_SLICE]]
+//       CHECK:   return %[[FOR_RESULT]]

@pashu123 (Member) left a comment:

LGTM!

b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
sliceSrcSizes, sliceSrcStrides);
generatedSlices.push_back(sliceSource);
Max191 (Contributor, Author):
I'm realizing that maybe I'm not supposed to add the input operand slices. Should I be doing this?

A contributor replied:
According to the doc, we should return the slices for fusion, so it looks like the change makes sense. Why do you think we should not do this?

/// - `generatedSlices` contains the list of slices that are generated during
///   tiling. These slices can be used for fusing producers.

Max191 (Contributor, Author):
I looked at a few other implementations and they seemed to only add the init operand slices, so I thought there was some context I was missing. But I checked some more and they add the input slices as well, so the ops I looked at first must have been special cases. I think the change is good as I had it.

@hanhanW (Contributor) left a comment:
I don't remember all the details, and the tiling mechanism has been evolving. Based on the comment of TilingResult, the change looks okay to me.

Max191 merged commit f1595ec into llvm:main on Oct 25, 2024. 11 checks passed.
frobtech mentioned this pull request on Oct 25, 2024.
Max191 added a commit to iree-org/llvm-project that referenced this pull request Oct 29, 2024
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024