From 86244de727b36b0f2113bba4df3985237abad68f Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 3 Feb 2025 15:54:53 -0500 Subject: [PATCH] [Encoding][Codegen] Add initial pad encoding layout attrs (#19865) These allow us to pad allocations without changing the logical tensor sizes or data layouts. Split the encoding layout attribute interface into two: * One with target-specific information that allows us to decide layouts. * One with serialized target-agnostic padding information. Signed-off-by: Jakub Kuderski --- compiler/plugins/target/ROCM/ROCMTarget.cpp | 10 ++ compiler/plugins/target/ROCM/test/BUILD.bazel | 1 + .../plugins/target/ROCM/test/CMakeLists.txt | 1 + .../target/ROCM/test/gpu_encoding_attrs.mlir | 26 +++++ .../Dialect/Codegen/IR/IREECodegenAttrs.td | 2 +- .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.td | 27 +++++ .../Dialect/GPU/TargetUtils/KnownTargets.cpp | 16 +++ .../Dialect/GPU/TargetUtils/KnownTargets.h | 4 + .../CPUEncodingExternalModels.cpp | 57 ++++++---- .../GPUEncodingExternalModels.cpp | 106 +++++++++++++++++- .../Codegen/ExternalInterfaces/Interfaces.cpp | 2 +- .../Dialect/Encoding/IR/EncodingAttrs.cpp | 23 ++-- .../Dialect/Encoding/IR/EncodingAttrs.td | 33 +++++- .../Dialect/Encoding/IR/EncodingInterfaces.td | 64 +++++++---- .../Transforms/test/specialize_encodings.mlir | 56 +++++++++ .../compiler/Utils/ElementPackingUtils.cpp | 5 +- 16 files changed, 365 insertions(+), 68 deletions(-) create mode 100644 compiler/plugins/target/ROCM/test/gpu_encoding_attrs.mlir diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp index ea05e9fd12e2..b7069712d7d6 100644 --- a/compiler/plugins/target/ROCM/ROCMTarget.cpp +++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp @@ -66,6 +66,7 @@ struct ROCMOptions { std::string bitcodeDirectory = getDefaultBitcodeDirectory(); int wavesPerEu = 0; std::string enableROCMUkernels = "none"; + bool experimentalPadLayout = false; bool slpVectorization = true; bool globalISel = false; @@ -105,6 +106,10 @@ struct ROCMOptions { cl::desc("Enables microkernels in the HIP compiler backend. May be " "`default`, `none`, `all`, or a comma-separated list of " "specific unprefixed microkernels to enable, e.g. `mmt4d`.")); + binder.opt("iree-hip-enable-experimental-pad-layout", + experimentalPadLayout, cl::cat(category), + cl::desc("Enables additional padding on allocations to " + "maximize cache bandwidth.")); binder.list( "iree-hip-pass-plugin-path", passPlugins, @@ -248,6 +253,11 @@ class ROCMTargetBackend final : public TargetBackend { if (auto target = GPU::getHIPTargetDetails( options.target, options.targetFeatures, context)) { addConfig("iree.gpu.target", target); + if (options.experimentalPadLayout) { + if (Attribute encoding = GPU::getHIPTargetEncodingLayoutAttr(target)) { + addConfig("encoding", encoding); + } + } } addConfig("ukernels", b.getStringAttr(options.enableROCMUkernels)); diff --git a/compiler/plugins/target/ROCM/test/BUILD.bazel b/compiler/plugins/target/ROCM/test/BUILD.bazel index 7201e4b988e8..a99cbc4cc828 100644 --- a/compiler/plugins/target/ROCM/test/BUILD.bazel +++ b/compiler/plugins/target/ROCM/test/BUILD.bazel @@ -19,6 +19,7 @@ iree_lit_test_suite( "config_ukernel_argmax_gfx942.mlir", "config_ukernel_multi_mma_gfx942.mlir", "default_tuning_specs_amdgpu.mlir", + "gpu_encoding_attrs.mlir", "lowering_strategy_from_tuning_spec.mlir", "ukernel_pipeline_transform.mlir", ], diff --git a/compiler/plugins/target/ROCM/test/CMakeLists.txt b/compiler/plugins/target/ROCM/test/CMakeLists.txt index fde029c3ce6d..c511ccf2dccd 100644 --- a/compiler/plugins/target/ROCM/test/CMakeLists.txt +++ b/compiler/plugins/target/ROCM/test/CMakeLists.txt @@ -18,6 +18,7 @@ iree_lit_test_suite( "config_ukernel_argmax_gfx942.mlir" "config_ukernel_multi_mma_gfx942.mlir" "default_tuning_specs_amdgpu.mlir" + "gpu_encoding_attrs.mlir" "lowering_strategy_from_tuning_spec.mlir" "ukernel_pipeline_transform.mlir" TOOLS diff --git a/compiler/plugins/target/ROCM/test/gpu_encoding_attrs.mlir b/compiler/plugins/target/ROCM/test/gpu_encoding_attrs.mlir new file mode 100644 index 000000000000..0342bdd488ff --- /dev/null +++ b/compiler/plugins/target/ROCM/test/gpu_encoding_attrs.mlir @@ -0,0 +1,26 @@ +// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' \ +// RUN: --iree-hip-target=gfx942 --iree-hip-enable-experimental-pad-layout %s | FileCheck %s --check-prefix=PAD +// +// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' \ +// RUN: --iree-hip-target=gfx90a --iree-hip-enable-experimental-pad-layout %s | FileCheck %s --check-prefix=PAD + +// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' \ +// RUN: --iree-hip-target=gfx90a --iree-hip-enable-experimental-pad-layout=false %s | FileCheck %s --check-prefix=NOPAD + +// PAD: #hal.executable.target<"rocm" +// PAD-SAME: encoding = #iree_gpu.gpu_pad_layout + +// NOPAD: #hal.executable.target<"rocm" +// NOPAD-NOT: encoding = #iree_gpu.gpu_pad_layout + +stream.executable public @main { + stream.executable.export @main workgroups(%arg0: index) -> (index, index, index) { + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + stream.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @main() { + return + } + } +} diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td index e5c6f6f649cd..986f776a867a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td @@ -417,7 +417,7 @@ def IREECodegen_ExportConfig : AttrDef } //===---------------------------------------------------------------------===// -// iree_codegen.encoding_layout +// iree_codegen.encoding_nop_layout //===---------------------------------------------------------------------===// def IREECodegen_EncodingNopLayoutAttr : diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td index fc85ae1c41f8..d0d75a18d40b 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td @@ -340,6 +340,33 @@ def IREEGPU_GPUEncodingLayoutAttr : ); } +//===----------------------------------------------------------------------===// +// iree_gpu.gpu_pad_layout +//===----------------------------------------------------------------------===// + +def IREEGPU_GPUPadLayoutAttr : AttrDef { + let mnemonic = "gpu_pad_layout"; + let summary = "The padded encoding layout attribute for GPU targets."; + let assemblyFormat = "`<` struct(params) `>`"; + + let description = [{ + Describes padding preferences for a given GPU target. + This attribute can implement any encoding interface for data-tiling, + e.g., Encoding::EncodingLayoutAttrInterface, etc. They should be implemented + through external model mechanism because we do not want to relocate + domain-specific logic to the dialect implementation, and we can have better + code structure. See the implementation in + compiler/Codegen/ExternalInterfaces/*. + }]; + + let parameters = (ins + // Relevant target properties that will later allow us to decide the + // serialized pad layout. + "uint32_t":$cache_line_bytes, + "uint32_t":$cache_sets + ); +} + //===----------------------------------------------------------------------===// // Workgroup processor level description //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp index 127c85ed37c5..eb6d306a6107 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp @@ -9,8 +9,12 @@ #include #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" namespace mlir::iree_compiler::IREE::GPU { @@ -707,6 +711,18 @@ TargetAttr getHIPTargetDetails(StringRef target, StringRef features, return nullptr; } +Attribute getHIPTargetEncodingLayoutAttr(TargetAttr target) { + // This is only enabled for CDNA2 and CDNA3 for the time being. + // TODO(kuhar): Enable for other HIP targets. + if (!llvm::is_contained({"gfx90a", "gfx940", "gfx941", "gfx942"}, + target.getArch())) { + return nullptr; + } + + return IREE::GPU::GPUPadLayoutAttr::get( + target.getContext(), /*cacheLineBytes=*/128, /*cacheSets=*/4); +} + StringRef normalizeHIPTarget(StringRef target) { return normalizeAMDGPUTarget(target); } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h index d9698cc912f0..0beb4cb62508 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h @@ -36,6 +36,10 @@ StringRef normalizeCUDATarget(StringRef target); TargetAttr getHIPTargetDetails(llvm::StringRef target, llvm::StringRef features, MLIRContext *context); +// Returns an attribute implementing `EncodingLayoutAttributeInterface` if +// |target| has known encoding preferences. +Attribute getHIPTargetEncodingLayoutAttr(TargetAttr target); + // Normalizes the given HIP |target| to the gfx target commonly used for // compiling towards HIP. For example, "gfx90a" for "cnda2", "gfx1100" for // "rx7900xtx". Returns empty StringRef if the given |target| is not recognized. diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp index b279db3712bc..95509e442617 100644 --- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp +++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp @@ -670,17 +670,9 @@ struct CPUDeviceEncodingLayoutAttrInterface } }; -struct CPUHostEncodingLayoutAttrInterface - : public IREE::Encoding::EncodingLayoutAttrInterface::ExternalModel< +struct CPUHostEncodingLayoutAttrInterface final + : IREE::Encoding::EncodingLayoutAttrInterface::ExternalModel< CPUHostEncodingLayoutAttrInterface, CPUEncodingLayoutAttr> { - - Value calculateStorageSizeInBytes(Attribute attr, Location loc, - OpBuilder &builder, RankedTensorType type, - ValueRange dynamicDims) const { - return calculateStorageSizeInBytesImpl(attr, loc, builder, type, - dynamicDims); - } - Attribute cloneWithSimplifiedConfig(Attribute attr, DictionaryAttr config) const { MLIRContext *ctx = attr.getContext(); @@ -697,6 +689,18 @@ struct CPUHostEncodingLayoutAttrInterface } }; +struct CPUHostSerializedEncodingLayoutAttrInterface final + : IREE::Encoding::SerializedEncodingLayoutAttrInterface::ExternalModel< + CPUHostSerializedEncodingLayoutAttrInterface, CPUEncodingLayoutAttr> { + + Value calculateStorageSizeInBytes(Attribute attr, Location loc, + OpBuilder &builder, RankedTensorType type, + ValueRange dynamicDims) const { + return calculateStorageSizeInBytesImpl(attr, loc, builder, type, + dynamicDims); + } +}; + //===----------------------------------------------------------------------===// // Interface methods implementaion for iree_cpu.vmvx_encoding_layout. //===----------------------------------------------------------------------===// @@ -731,8 +735,8 @@ enumerateVMVXMatmulTiles(linalg::ContractionDimensions cDims, }; } -struct VMVXDeviceEncodingLayoutAttrInterface - : public Codegen::LayoutAttrInterface::ExternalModel< +struct VMVXDeviceEncodingLayoutAttrInterface final + : Codegen::LayoutAttrInterface::ExternalModel< VMVXDeviceEncodingLayoutAttrInterface, VMVXEncodingLayoutAttr> { MaterializeEncodingInfo getEncodingInfo(Attribute attr, RankedTensorType type) const { @@ -797,16 +801,9 @@ struct VMVXDeviceEncodingLayoutAttrInterface } }; -struct VMVXHostEncodingLayoutAttrInterface - : public IREE::Encoding::EncodingLayoutAttrInterface::ExternalModel< +struct VMVXHostEncodingLayoutAttrInterface final + : IREE::Encoding::EncodingLayoutAttrInterface::ExternalModel< VMVXHostEncodingLayoutAttrInterface, VMVXEncodingLayoutAttr> { - Value calculateStorageSizeInBytes(Attribute attr, Location loc, - OpBuilder &builder, RankedTensorType type, - ValueRange dynamicDims) const { - return calculateStorageSizeInBytesImpl(attr, loc, builder, type, - dynamicDims); - } - Attribute cloneWithSimplifiedConfig(Attribute attr, DictionaryAttr config) const { MLIRContext *ctx = attr.getContext(); @@ -822,6 +819,18 @@ struct VMVXHostEncodingLayoutAttrInterface } }; +struct VMVXHostSerializedEncodingLayoutAttrInterface final + : IREE::Encoding::SerializedEncodingLayoutAttrInterface::ExternalModel< + VMVXHostSerializedEncodingLayoutAttrInterface, + VMVXEncodingLayoutAttr> { + Value calculateStorageSizeInBytes(Attribute attr, Location loc, + OpBuilder &builder, RankedTensorType type, + ValueRange dynamicDims) const { + return calculateStorageSizeInBytesImpl(attr, loc, builder, type, + dynamicDims); + } +}; + } // namespace void registerCPUEncodingExternalModels(DialectRegistry ®istry) { @@ -829,10 +838,12 @@ void registerCPUEncodingExternalModels(DialectRegistry ®istry) { +[](MLIRContext *ctx, IREE::CPU::IREECPUDialect *dialect) { IREE::CPU::CPUEncodingLayoutAttr::attachInterface< CPUDeviceEncodingLayoutAttrInterface, - CPUHostEncodingLayoutAttrInterface>(*ctx); + CPUHostEncodingLayoutAttrInterface, + CPUHostSerializedEncodingLayoutAttrInterface>(*ctx); IREE::CPU::VMVXEncodingLayoutAttr::attachInterface< VMVXDeviceEncodingLayoutAttrInterface, - VMVXHostEncodingLayoutAttrInterface>(*ctx); + VMVXHostEncodingLayoutAttrInterface, + VMVXHostSerializedEncodingLayoutAttrInterface>(*ctx); }); } diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/GPUEncodingExternalModels.cpp b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/GPUEncodingExternalModels.cpp index 8920e4cd030d..4bb0fc04f967 100644 --- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/GPUEncodingExternalModels.cpp +++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/GPUEncodingExternalModels.cpp @@ -5,8 +5,9 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception //===- GPUEncodingExternalModels.cpp --------------------------------------===// // -// This file implements the IREE::Codegen::LayoutAttrInterface for GPU backends. -// Different from CPU backends, we do not tranpose narrow-N to narrow-M for a +// This file implements the IREE::Codegen::LayoutAttrInterface and +// IREE::Encoding::EncodingLayoutAttrInterface for GPU backends. +// Different from CPU backends, we do not transpose narrow-N to narrow-M for a // combination of reasons: // // 1. As linalg.matmul materializes into iree_gpu.multi_mma, which inherits @@ -21,17 +22,22 @@ #include "iree/compiler/Codegen/ExternalInterfaces/GPUEncodingExternalModels.h" -#include - -#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h" #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h" +#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" + +#include +#include +#include +#include #define DEBUG_TYPE "iree-gpu-encoding-external-models" @@ -333,6 +339,94 @@ struct GPUDeviceEncodingLayoutAttrInterface } }; +struct GPUPadEncodingLayoutAttrInterface final + : Encoding::EncodingLayoutAttrInterface::ExternalModel< + GPUPadEncodingLayoutAttrInterface, GPUPadLayoutAttr> { + Attribute cloneWithSimplifiedConfig(Attribute attr, + DictionaryAttr /*config*/) const { + // This attribute is self-contained and does not need to look anything up + // from the target `config`. + return attr; + } + + Attribute getLayout(Attribute attr, RankedTensorType type) const { + MLIRContext *ctx = attr.getContext(); + auto padLayoutAttr = cast(attr); + auto encodingAttr = cast(type.getEncoding()); + + const int64_t rank = type.getRank(); + SmallVector padValues(rank, 0); + auto noPaddingAttr = Encoding::PadEncodingLayoutAttr::get( + ctx, DenseI32ArrayAttr::get(ctx, padValues)); + if (encodingAttr.getOpType().getValue() != + IREE::Encoding::EncodingOpType::matmul) { + // We only support simple matmuls for now. + return noPaddingAttr; + } + + const int64_t operandIndex = encodingAttr.getOperandIndex().getInt(); + if (!llvm::is_contained({0, 1}, operandIndex)) { + // We only have to pad matmul operands. + return noPaddingAttr; + } + + // We only support simple matmuls for now. Filter out everything that + // does not have a simple row-major access pattern with a single static + // reduction dimension. + FailureOr contractionDims = + Encoding::getEncodingContractionDims(encodingAttr); + if (failed(contractionDims) || contractionDims->k.size() != 1) { + return noPaddingAttr; + } + + std::optional padDimensionIndex = + encodingAttr.mapDimToOperandIndex(contractionDims->k[0]); + if (!padDimensionIndex || padDimensionIndex != rank - 1) { + return noPaddingAttr; + } + ArrayRef shape = type.getShape(); + if (ShapedType::isDynamic(shape[*padDimensionIndex])) { + return noPaddingAttr; + } + + const int64_t elementBits = type.getElementTypeBitWidth(); + const int64_t cacheLineBytes = padLayoutAttr.getCacheLineBytes(); + if (elementBits % 8 != 0 || elementBits > cacheLineBytes) { + // We do not support unaligned element types. + return noPaddingAttr; + } + + // Attempt to maximize L1 cache bandwidth by engaging all cache sets. + // We want to make sure that the reduction dimension is a multiple of the + // cache line, but not a multiple of cache line * cache sets. This way the + // next 'row' will start at a different cache set. + const int64_t cacheSetSpanBytes = + padLayoutAttr.getCacheSets() * cacheLineBytes; + const int64_t dimSizeInBytes = + type.getDimSize(*padDimensionIndex) * (elementBits / 8); + int64_t padBytes = 0; + if (int64_t unalignedBytes = dimSizeInBytes % cacheLineBytes; + unalignedBytes != 0) { + // First, pad to the multiple of cache lines. + padBytes += cacheLineBytes - unalignedBytes; + } + + if ((dimSizeInBytes + padBytes) % cacheSetSpanBytes == 0) { + // Pad by one cache line to engage all cache sets. + padBytes += cacheLineBytes; + } + + assert((dimSizeInBytes + padBytes) % cacheLineBytes == 0 && + "Incorrect pad amount"); + assert(padBytes < cacheSetSpanBytes && "Incorrect pad amount"); + const int64_t numPadElements = (padBytes * 8) / elementBits; + padValues[*padDimensionIndex] = numPadElements; + auto padLayout = Encoding::PadEncodingLayoutAttr::get( + ctx, DenseI32ArrayAttr::get(ctx, padValues)); + return padLayout; + } +}; + } // namespace void registerGPUEncodingExternalModels(DialectRegistry ®istry) { @@ -340,6 +434,8 @@ void registerGPUEncodingExternalModels(DialectRegistry ®istry) { +[](MLIRContext *ctx, IREE::GPU::IREEGPUDialect *dialect) { IREE::GPU::GPUEncodingLayoutAttr::attachInterface< GPUDeviceEncodingLayoutAttrInterface>(*ctx); + IREE::GPU::GPUPadLayoutAttr::attachInterface< + GPUPadEncodingLayoutAttrInterface>(*ctx); }); } diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/Interfaces.cpp b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/Interfaces.cpp index 2f3b36377ec5..b098197f91b4 100644 --- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/Interfaces.cpp +++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/Interfaces.cpp @@ -12,8 +12,8 @@ namespace mlir::iree_compiler { void registerCodegenExternalInterfaces(DialectRegistry ®istry) { - IREE::GPU::registerGPUEncodingExternalModels(registry); IREE::CPU::registerCPUEncodingExternalModels(registry); + IREE::GPU::registerGPUEncodingExternalModels(registry); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp index 2c7a9d593dda..b76fd1dbc0a8 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp @@ -158,18 +158,15 @@ Value EncodingAttr::calculateStorageSizeInBytes(Location loc, OpBuilder &builder, RankedTensorType type, ValueRange dynamicDims) const { - if (auto layoutsAttr = getLayouts()) { - if (llvm::any_of(layoutsAttr.getValue(), [](Attribute attr) { - return !llvm::isa(attr); - })) { - return Value(); + if (ArrayAttr layoutsAttr = getLayouts()) { + if (!llvm::all_of(layoutsAttr.getValue(), + llvm::IsaPred)) { + return nullptr; } - auto layoutsAttrArray = - llvm::to_vector_of( - layoutsAttr.getValue()); Value res; - for (auto attr : layoutsAttrArray) { + for (auto attr : + layoutsAttr.getAsRange()) { Value requestedSize = attr.calculateStorageSizeInBytes(loc, builder, type, dynamicDims); if (!res) { @@ -313,6 +310,14 @@ std::string stringifyOperandIndex(IntegerAttr valueAttr) { } } +Value PadEncodingLayoutAttr::calculateStorageSizeInBytes( + Location loc, OpBuilder &builder, RankedTensorType type, + ValueRange dynamicDims) const { + // TODO(kuhar): Add sizeof calculation. + assert(false && "Unimplemented"); + return nullptr; +} + //===---------------------------------------------------------------------===// // Encoding specialization attributes, which are mainly for testing purpose. //===---------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td index 4fb6de70bcea..868e58fb9116 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td @@ -54,7 +54,7 @@ def PackedStorageAttr : IREEEncoding_Attr<"PackedStorage"> { def EncodingAttr : IREEEncoding_Attr<"Encoding", [ - DeclareAttrInterfaceMethods ]> { @@ -133,6 +133,35 @@ def EncodingAttr : let genVerifyDecl = 0; } +//===---------------------------------------------------------------------===// +// encoding.pad_encoding_layout +//===---------------------------------------------------------------------===// + +def PadEncodingLayoutAttr : IREEEncoding_Attr<"PadEncodingLayout", [ + DeclareAttrInterfaceMethods + ]> { + let mnemonic = "pad_encoding_layout"; + let assemblyFormat = "`<` $padding `>`"; + + let summary = "An attribute that encodes padding values of tensor dimensions"; + let description = [{ + Associates tensor dimensions with pad values (numbers of appended elements). + The logical dimensions of the tensors do not change, and the elements in the + padded regions are left uninitialized. + + This attribute implements `Encoding::SerializedEncodingLayoutAttrInterface`, + to provide a hook for tensor sizeof lowering. The implementation of this + interface is backend-agnostic, but the emission of the pad encoding attribute + itself can be target- or domain-specific. + }]; + + let parameters = (ins + // How many padding elements to add along each tensor dimension. + "DenseI32ArrayAttr":$padding + ); +} + //===---------------------------------------------------------------------===// // Encoding specialization attributes, which are mainly for testing purpose. //===---------------------------------------------------------------------===// @@ -175,7 +204,7 @@ def SpecializedEncodingAttr : let summary = "An attribute that indicates the encoding is specialized"; let description = [{ - This attribute is similar to UnspecializeEncodingAttr, but with an optional + This attribute is similar to UnspecializedEncodingAttr, but with an optional type. The attribute denotes the layout of the type. Different seed values indicate different layouts, which can be used to emulate different encoding attributes. diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingInterfaces.td b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingInterfaces.td index ff5033057d0d..d5905b17be6d 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingInterfaces.td +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingInterfaces.td @@ -25,30 +25,10 @@ def IREEEncoding_EncodingLayoutAttrInterface : based on their needs. TBD. The current expectation of the interface is to propagate layout - information from backends to the host compliation or other targets. + information from backends to the host compilation or other targets. }]; let methods = [ - InterfaceMethod< - /*desc=*/[{ - Returns the storage size (in bytes) for the tensor types with an - optional encoding. Returns Value() if the size is unknown, i.e., it can - not be inferred with existing information. - }], - /*retTy=*/"::mlir::Value", - /*methodName=*/"calculateStorageSizeInBytes", - /*args=*/(ins - "::mlir::Location":$loc, - "::mlir::OpBuilder &":$builder, - "RankedTensorType":$type, - "ValueRange":$dynamicDims - ), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(false && "unimplemented interface method"); - return {}; - }] - >, InterfaceMethod< /*desc=*/[{ Returns the attribute with simplified configuration/layouts. Attribute @@ -57,7 +37,7 @@ def IREEEncoding_EncodingLayoutAttrInterface : configuration (e.g., cpu features) for further lowering. However, some configuration/parameters can be dropped as long as they are no longer needed in the progressively lowering. This method provides a mechanism - for such attribute to drop the outdated paramters and makes IR dump less + for such attribute to drop the outdated parameters and makes IR dump less verbose. }], /*retTy=*/"::mlir::Attribute", @@ -73,8 +53,9 @@ def IREEEncoding_EncodingLayoutAttrInterface : >, InterfaceMethod< /*desc=*/[{ - Returns the serialized layout, which is either common format or wrapped - by an attribute that implements EncodingLayoutAttrInterface interface. + Returns the an attribute implementing the which is either common format + or wrapped by an attribute that implements the + `SerializedEncodingLayoutAttrInterface` interface. If it is in common format (e.g., a regular tensor type), we can easily calculate the storage size. Otherwise, we will need a hook from external, and the hook can come from an attribute that implements the @@ -94,6 +75,41 @@ def IREEEncoding_EncodingLayoutAttrInterface : ]; } +def IREEEncoding_SerializedEncodingAttrInterface : + AttrInterface<"SerializedEncodingLayoutAttrInterface"> { + let cppNamespace = "::mlir::iree_compiler::IREE::Encoding"; + let description = [{ + Interface used to query serialized layout information needed to materialize + encoding attributes. + + The attributes implementing this interface may be target-specific or general + enough to be shared across backends, depending on the layouts used. + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Returns the storage size (in bytes) for the tensor types with an + optional encoding. Returns Value() if the size is unknown, i.e., it can + not be inferred with existing information. + }], + /*retTy=*/"::mlir::Value", + /*methodName=*/"calculateStorageSizeInBytes", + /*args=*/(ins + "::mlir::Location":$loc, + "::mlir::OpBuilder &":$builder, + "RankedTensorType":$type, + "ValueRange":$dynamicDims + ), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(false && "unimplemented interface method"); + return {}; + }] + > + ]; +} + //===----------------------------------------------------------------------===// // Type Interfaces //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir index 520f7eabde7b..2c71a86e1639 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir @@ -144,6 +144,62 @@ module attributes {stream.affinity.default = #hal.device.affinity<@device_a>} { // CHECK: stream.tensor.dispatch on(#hal.device.affinity<@[[$DEVICE_B]]>) @[[$EX0]]::@dispatch // CHECK-SAME: #[[$ENCODING]] +//------------------------------------------------------------------------------ +// iree_gpu.gpu_pad_encoding specialization tests. +// These get serialized to iree_encoding.pad_encoding_layout attributes. +//------------------------------------------------------------------------------ + +// ----- + +#map0 = affine_map<(m, n, k) -> (m, k)> +#map1 = affine_map<(m, n, k) -> (n, k)> +#map2 = affine_map<(m, n, k) -> (m, n)> +#map3 = affine_map<(m, n, k) -> (n, k)> +#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", + encoding = #iree_gpu.gpu_pad_layout, ukernels = "none"}> +#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_rocm_hsaco_fb]> : !hal.device +#encodingA = #iree_encoding.encoding +#encodingB = #iree_encoding.encoding +#encodingC = #iree_encoding.encoding +#encodingD = #iree_encoding.encoding +module { + util.global private @device_a = #device_target_local_0_ + + util.func public @with_pad_encoding(%arg0: index, %arg1: index, %scalar_f32 : f32) { + %0 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor<4096x4096xf16, #encodingA>{} in !stream.resource<*>{%arg1} + %1 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor<4096x4160xf16, #encodingA>{} in !stream.resource<*>{%arg1} + %2 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor<4096x1337xf16, #encodingA>{} in !stream.resource<*>{%arg1} + %3 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor<4096x4095xf16, #encodingA>{} in !stream.resource<*>{%arg1} + %4 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor{%arg0} in !stream.resource<*>{%arg1} + %5 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor{%arg0, %arg1} in !stream.resource<*>{%arg1} + %6 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor<4096x4096xf16, #encodingB>{} in !stream.resource<*>{%arg1} + %7 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor<4096x4096xf16, #encodingC>{} in !stream.resource<*>{%arg1} + %8 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor<4096x4096xf16, #encodingD>{} in !stream.resource<*>{%arg1} + util.return + } +} + +// CHECK-DAG: #[[$NO_PAD_LHS:.+]] = #iree_encoding.encoding] +// CHECK-DAG: #[[$NO_PAD_OUT:.+]] = #iree_encoding.encoding] +// CHECK-DAG: #[[$PAD_LHS_0:.+]] = #iree_encoding.encoding] +// CHECK-DAG: #[[$PAD_LHS_1:.+]] = #iree_encoding.encoding] +// CHECK-DAG: #[[$PAD_LHS_2:.+]] = #iree_encoding.encoding] +// CHECK-DAG: #[[$PAD_RHS:.+]] = #iree_encoding.encoding] + +// CHECK-LABEL: util.func public @with_pad_encoding +// +// CHECK: stream.tensor.empty {{.*}} : tensor<4096x4096xf16, #[[$PAD_LHS_0]]> +// CHECK: stream.tensor.empty {{.*}} : tensor<4096x4160xf16, #[[$NO_PAD_LHS]]> +// CHECK: stream.tensor.empty {{.*}} : tensor<4096x1337xf16, #[[$PAD_LHS_1]]> +// CHECK: stream.tensor.empty {{.*}} : tensor<4096x4095xf16, #[[$PAD_LHS_2]]> +// CHECK: stream.tensor.empty {{.*}} : tensor +// CHECK: stream.tensor.empty {{.*}} : tensor +// CHECK: stream.tensor.empty {{.*}} : tensor<4096x4096xf16, #[[$PAD_RHS]]> +// CHECK: stream.tensor.empty {{.*}} : tensor<4096x4096xf16, #[[$NO_PAD_OUT]]> +// CHECK: stream.tensor.empty {{.*}} : tensor<4096x4096xf16, #[[$PAD_RHS]]> +// +// CHECK-NEXT: util.return + // ----- // Tests that launch the executable on device_a, pass the result to device_b and diff --git a/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp b/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp index 9fe2e7986b5e..d7b3258d5e45 100644 --- a/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp +++ b/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp @@ -90,9 +90,8 @@ Value calculateStorageElementCountInBytes(Location loc, ValueRange dynamicDims, OpBuilder &builder) { Attribute encoding = shapedType.getEncoding(); - if (auto encodingLayoutAttr = - dyn_cast_or_null( - encoding)) { + if (auto encodingLayoutAttr = dyn_cast_or_null< + IREE::Encoding::SerializedEncodingLayoutAttrInterface>(encoding)) { return encodingLayoutAttr.calculateStorageSizeInBytes( loc, builder, shapedType, dynamicDims); }