diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp index 104d6ae1f9f6b..ba41904b37099 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -554,13 +554,14 @@ struct UnPackOpTiling sliceSrcIndices.append(numInnerTiles, zeroAttr); sliceSrcSizes.append(unpackOp.getMixedTiles()); sliceSrcStrides.append(numInnerTiles, oneAttr); - Value sliceSource = + SmallVector generatedSlices; + ExtractSliceOp sliceSource = b.create(loc, unpackOp.getSource(), sliceSrcIndices, sliceSrcSizes, sliceSrcStrides); + generatedSlices.push_back(sliceSource); SmallVector destStrides(destRank, oneAttr); Value sliceDest; - SmallVector generatedSlices; if (isPerfectTilingCase) { auto destSliceOp = b.create(loc, unpackOp.getDest(), offsets, sizes, destStrides); @@ -571,7 +572,7 @@ struct UnPackOpTiling unpackOp.getDestType().getElementType()); } - SmallVector tiledOperands = {sliceSource, sliceDest}; + SmallVector tiledOperands = {sliceSource.getResult(), sliceDest}; for (auto tile : unpackOp.getInnerTiles()) tiledOperands.push_back(tile); @@ -586,7 +587,6 @@ struct UnPackOpTiling auto extractSlice = b.create(loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes, destStrides); - generatedSlices.push_back(extractSlice); return TilingResult{ {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices}; } diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir index 3ea1929e4ed78..5f7663af773a4 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir @@ -587,3 +587,50 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[IF_RESULT]] // CHECK: scf.yield %[[INSERT_SLICE]] // CHECK: return %[[FOR_RESULT]] + +// ----- + +func.func @imperfect_unpack_producer_fusion(%source: tensor<1x1x288x8x4xf32>, %dest: tensor<1x2x1152xf32>) -> tensor<1x2x1152xf32> { + %0 = tensor.unpack %source + outer_dims_perm = [0, 1, 2] + inner_dims_pos = [1, 2] + inner_tiles = [8, 4] into %dest + : tensor<1x1x288x8x4xf32> -> tensor<1x2x1152xf32> + %1 = tensor.empty() : tensor<1x2x1152xf32> + %cst = arith.constant 1.0 : f32 + %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%0 : tensor<1x2x1152xf32>) + outs(%1 : tensor<1x2x1152xf32>) { + ^bb0(%in: f32, %out: f32): + %7 = arith.addf %in, %cst : f32 + linalg.yield %7 : f32 + } -> tensor<1x2x1152xf32> + return %2 : tensor<1x2x1152xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.structured.fuse %matmul [0, 1, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// CHECK-LABEL: func @imperfect_unpack_producer_fusion +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x288x8x4xf32> +// CHECK-SAME: %[[ARG1:.+]]: tensor<1x2x1152xf32> +// CHECK: %[[FOR_RESULT:.+]] = scf.for{{.*}}iter_args(%[[ITER_ARG:.+]] = {{.*}}) +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]] +// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[SLICE]] +// CHECK-DAG: %[[UNPACK_SLICE:.+]] = tensor.extract_slice %[[UNPACK]] +// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]] +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[UNPACK_SLICE]] +// CHECK-SAME: outs(%[[INIT_SLICE]] +// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER_ARG]] +// CHECK: scf.yield %[[INSERT_SLICE]] +// CHECK: return %[[FOR_RESULT]]