diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index ad399f57f72cb..a131f30976661 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -28,6 +28,7 @@ #include "mlir/Transforms/FoldUtils.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -142,6 +143,8 @@ struct LinalgOpInstancePromotionOptions { const LinalgPromotionOptions &options); /// SubViews to promote. MapVector subViews; + /// Operand numbers of the subviews to copy in using copyInFn. + llvm::SmallSet operandsNumbersToCopyIn; /// True if the full view should be used for the promoted buffer. DenseMap useFullTileBuffers; @@ -174,6 +177,11 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions( Operation *op = opOperand.get().getDefiningOp(); if (auto sv = dyn_cast_or_null(op)) { subViews[operandNumber] = sv; + // In case of a linalg generic op, copy in only if the subview is used in + // the payload.
+ if (!isa(linalgOp) || + linalgOp.payloadUsesValueFromOperand(&opOperand)) + operandsNumbersToCopyIn.insert(operandNumber); useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber]; } } @@ -324,6 +332,8 @@ promoteSubViews(ImplicitLocOpBuilder &b, auto info = promotionInfoMap.find(v.first); if (info == promotionInfoMap.end()) continue; + if (options.operandsNumbersToCopyIn.count(v.first) == 0) + continue; if (failed(options.copyInFn( b, cast(v.second.getDefiningOp()), info->second.partialLocalView))) diff --git a/mlir/test/Dialect/GPU/promotion.mlir b/mlir/test/Dialect/GPU/promotion.mlir index b4668b5678894..2da1be597753b 100644 --- a/mlir/test/Dialect/GPU/promotion.mlir +++ b/mlir/test/Dialect/GPU/promotion.mlir @@ -1,3 +1,4 @@ + // RUN: mlir-opt -allow-unregistered-dialect -pass-pipeline='builtin.module(gpu.module(gpu.func(test-gpu-memory-promotion)))' -split-input-file %s | FileCheck %s gpu.module @foo { diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir index 5cd56db7fd2d8..31b29c0e105d9 100644 --- a/mlir/test/Dialect/Linalg/promote.mlir +++ b/mlir/test/Dialect/Linalg/promote.mlir @@ -353,7 +353,6 @@ func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf // CHECK: %[[VAL_62:.*]] = memref.subview %[[VAL_61]][0, 0] {{\[}}%[[VAL_52]], %[[VAL_55]]] [1, 1] : memref> to memref, #gpu.address_space> // CHECK: memref.copy %[[VAL_3]], %[[VAL_24]] : memref<4x3xf32, strided<[4, 1]>, 1> to memref, #gpu.address_space> // CHECK: memref.copy %[[VAL_4]], %[[VAL_43]] : memref<4x3xf32, strided<[4, 1]>, 1> to memref, #gpu.address_space> - // CHECK: memref.copy %[[VAL_5]], %[[VAL_62]] : memref<4x3xf32, strided<[4, 1]>, 1> to memref, #gpu.address_space> // CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"], library_call = ""} ins(%[[VAL_24]], %[[VAL_43]] : memref, #gpu.address_space>, memref, #gpu.address_space>) outs(%[[VAL_62]] : memref, 
#gpu.address_space>) { // CHECK: ^bb0(%[[VAL_63:.*]]: f32, %[[VAL_64:.*]]: f32, %[[VAL_65:.*]]: f32): // CHECK: %[[VAL_66:.*]] = arith.addf %[[VAL_63]], %[[VAL_64]] : f32