From 852684a23661dd4c9fdca733f7520cf3aa28f372 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Mon, 8 Jan 2024 17:58:05 -0500 Subject: [PATCH] [spirv][vulkan] Enable device query generation and execution (#15977) This commit adds a pass to materialize executable required SPIR-V capabilities into proper device queries inside the associated hal.executable.condition ops. Linking is updated accordingly to unique and preseve the feature checks. The Vulkan HAL driver is updated accordingly to probe the implementation and match against the device queries. Fixes https://github.com/openxla/iree/issues/15786 --- .../iree/compiler/Codegen/SPIRV/BUILD.bazel | 1 + .../compiler/Codegen/SPIRV/CMakeLists.txt | 1 + .../iree/compiler/Codegen/SPIRV/Passes.cpp | 10 + .../src/iree/compiler/Codegen/SPIRV/Passes.h | 5 + .../src/iree/compiler/Codegen/SPIRV/Passes.td | 9 + .../Codegen/SPIRV/SPIRVLinkExecutables.cpp | 29 +- .../SPIRVMaterializeExecutableConditions.cpp | 320 ++++++++++++++++++ .../SPIRV/SPIRVSelectLoweringStrategy.cpp | 6 - .../SPIRVTileAndVectorizeToCooperativeOps.cpp | 19 +- .../SPIRV/SPIRVTrimExecutableTargetEnv.cpp | 19 +- .../src/iree/compiler/Codegen/SPIRV/Utils.cpp | 17 +- .../src/iree/compiler/Codegen/SPIRV/Utils.h | 4 + .../compiler/Codegen/SPIRV/test/BUILD.bazel | 1 + .../Codegen/SPIRV/test/CMakeLists.txt | 1 + .../Codegen/SPIRV/test/link_executables.mlir | 87 +++-- .../materialize_executable_conditions.mlir | 236 +++++++++++++ .../iree/compiler/Codegen/Utils/BUILD.bazel | 1 + .../compiler/Codegen/Utils/CMakeLists.txt | 1 + .../compiler/Codegen/Utils/LinkingUtils.cpp | 25 ++ .../iree/compiler/Dialect/HAL/IR/HALOps.cpp | 12 + .../iree/compiler/Dialect/HAL/IR/HALOps.td | 5 + .../drivers/vulkan/dynamic_symbol_tables.h | 2 +- .../hal/drivers/vulkan/extensibility_util.h | 36 ++ .../src/iree/hal/drivers/vulkan/handle_util.h | 8 + .../iree/hal/drivers/vulkan/vulkan_device.cc | 231 ++++++++++++- 25 files changed, 1000 insertions(+), 86 deletions(-) create mode 100644 compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp create mode 100644 compiler/src/iree/compiler/Codegen/SPIRV/test/materialize_executable_conditions.mlir diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel index 6e1f20720755..86edaec83ec6 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel @@ -66,6 +66,7 @@ iree_compiler_cc_library( "SPIRVLinkExecutables.cpp", "SPIRVLowerExecutableTargetPass.cpp", "SPIRVMapMemRefStorageClass.cpp", + "SPIRVMaterializeExecutableConditions.cpp", "SPIRVSelectLoweringStrategy.cpp", "SPIRVTile.cpp", "SPIRVTileAndDistribute.cpp", diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt index 7d4d5229432b..6fd74650453a 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt @@ -65,6 +65,7 @@ iree_cc_library( "SPIRVLinkExecutables.cpp" "SPIRVLowerExecutableTargetPass.cpp" "SPIRVMapMemRefStorageClass.cpp" + "SPIRVMaterializeExecutableConditions.cpp" "SPIRVSelectLoweringStrategy.cpp" "SPIRVTile.cpp" "SPIRVTileAndDistribute.cpp" diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp index 3049d59151c7..13f7dcdd55ef 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp @@ -682,6 +682,16 @@ void buildSPIRVCodegenPassPipeline(OpPassManager &pm, bool enableFastMath) { // NOTE: this runs on the top-level program module containing all hal.executable // ops. void buildSPIRVLinkingPassPipeline(OpPassManager &passManager) { + auto &nestedExecutablePM = passManager.nest(); + // Trim the allowed target environment (version/capability/extension/etc.) to + // the minimal requirement needed by compiled spirv.module ops. This helps to + // increase the chance of linking different variant ops together. + nestedExecutablePM.addNestedPass( + createSPIRVTrimExecutableTargetEnvPass()); + // Materialize the minimal required target environment into proper device + // queries to execute in the runtime. + nestedExecutablePM.addNestedPass( + createSPIRVMaterializeExecutableConditionsPass()); // Link together executables. This may produce some IR duplication. passManager.addPass(createSPIRVLinkExecutablesPass()); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h index 63b8c14e738b..dc4a22c2c3c7 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h @@ -134,6 +134,11 @@ createSPIRVLowerExecutableTargetPass(); std::unique_ptr> createSPIRVMapMemRefStorageClassPass(); +/// Pass to materialize SPIR-V target requirements of hal.exectuable.variant ops +/// into hal.executable.condition regions. +std::unique_ptr> +createSPIRVMaterializeExecutableConditionsPass(); + /// Pass to tile and distribute Linalg ops with buffer semantics to /// invocations. std::unique_ptr> createSPIRVTileAndDistributePass(); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td index 1456ed625c6a..8ec18269f8d8 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td @@ -91,6 +91,15 @@ def SPIRVMapMemRefStorageClass : let constructor = "mlir::iree_compiler::createSPIRVMapMemRefStorageClassPass()"; } +def SPIRVMaterializeExecutableConditions : + Pass<"iree-spirv-materialize-executable-conditions", + "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> { + let summary = "Materialize SPIR-V target requirements of hal.exectuable.variant " + "ops into hal.executable.condition regions"; + let constructor = + "mlir::iree_compiler::createSPIRVMaterializeExecutableConditionsPass()"; +} + def SPIRVSelectLoweringStrategy : Pass<"iree-spirv-select-lowering-strategy-pass", "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> { diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp index df77b24c3ee3..bd389783ebba 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp @@ -6,15 +6,14 @@ #include "iree/compiler/Codegen/SPIRV/PassDetail.h" #include "iree/compiler/Codegen/SPIRV/Passes.h" +#include "iree/compiler/Codegen/SPIRV/Utils.h" #include "iree/compiler/Codegen/Utils/LinkingUtils.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" #include "iree/compiler/Utils/ModuleUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" -#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/Pass/Pass.h" #define DEBUG_TYPE "iree-spirv-link-executable" @@ -22,23 +21,19 @@ namespace mlir::iree_compiler { namespace IREE::HAL { -// Compares two ExecutableTargetAttr according to the order of used SPIR-V -// capabilities. +// Compares two ExecutableTargetAttr according to the alphabetical order of used +// SPIR-V features. // // Note that this is a very specific ordering per the needs of this pass--we // guarantee that input ExectuableTargetAttr only differ w.r.t. their used // SPIR-V features, and we want a deterministic order when mutating the IR. bool operator<(const ExecutableTargetAttr &a, const ExecutableTargetAttr &b) { - auto aTarget = a.getConfiguration().getAs( - spirv::getTargetEnvAttrName()); - auto bTarget = b.getConfiguration().getAs( - spirv::getTargetEnvAttrName()); - auto aFeatures = aTarget.getCapabilitiesAttr(); - auto bFeatures = bTarget.getCapabilitiesAttr(); + auto aFeatures = a.getConfiguration().getAs("iree.spirv.features"); + auto bFeatures = b.getConfiguration().getAs("iree.spirv.features"); for (unsigned i = 0; i < std::min(aFeatures.size(), bFeatures.size()); ++i) { if (aFeatures[i] != bFeatures[i]) { - return cast(aFeatures[i]).getInt() < - cast(bFeatures[i]).getInt(); + return cast(aFeatures[i]).getValue() < + cast(bFeatures[i]).getValue(); } } return aFeatures.size() < bFeatures.size(); @@ -49,11 +44,6 @@ namespace { using IREE::HAL::ExecutableTargetAttr; -bool isSPIRVBasedBackend(IREE::HAL::ExecutableVariantOp variantOp) { - return variantOp.getTargetAttr().getConfiguration().contains( - spirv::getTargetEnvAttrName()); -} - struct SPIRVLinkExecutablesPass final : SPIRVLinkExecutablesBase { void runOnOperation() override { @@ -104,9 +94,8 @@ struct SPIRVLinkExecutablesPass final // sort as the unique key. currentTargets.clear(); for (auto variant : executable.getOps()) { - ExecutableTargetAttr target = variant.getTarget(); - if (isSPIRVBasedBackend(variant)) { - currentTargets.push_back(target); + if (usesSPIRVCodeGen(variant)) { + currentTargets.push_back(variant.getTarget()); } } llvm::sort(currentTargets); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp new file mode 100644 index 000000000000..182e65e8a279 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp @@ -0,0 +1,320 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/SPIRV/PassDetail.h" +#include "iree/compiler/Codegen/SPIRV/Passes.h" +#include "iree/compiler/Codegen/SPIRV/Utils.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" +#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler { + +namespace { + +// The list of device features potentially required by a particular kernel. +// +// Note that the fields used here should match the ones used in +// iree_hal_vulkan_device_properties_t on the runtime side. +struct KernelFeatures { + // Floating-point compute related feature bitfield: + // * 0b01: f16 + // * 0b10: f64 + // Note that f32 is assumed to always exist and does not appear in this + // bitfield. + uint32_t computeFloat; + // Integer compute related feature bitfield: + // * 0b001: i8 + // * 0b010: i16 + // * 0b100: i64 + // Note that i32 or i1 is assumed to always exist and does not appear in + // this bitfield. + uint32_t computeInt; + // Storage bitwidth requirement bitfiled: + // * 0b01: 8-bit + // * 0b10: 16-bit + uint32_t storage; + // Subgroup operation requirement bitfield: + // * 0b01: subgroup shuffle operations + // * 0b10: subgroup arithmetic operations + uint32_t subgroup; + // Dot product operation requirement bitfield: + // ("dotprod..") + // * 0b01: dotprod.4xi8.i32 + uint32_t dotProduct; + // Cooperative matrix requirement bitfield: + // ("coopmatrix...xx") + // * 0b01: coopmatrix.f16.f16.16x16x16 + uint32_t coopMatrix; + + KernelFeatures() + : computeFloat(0), computeInt(0), storage(0), subgroup(0), dotProduct(0), + coopMatrix(0) {} + + bool empty() const { + return computeFloat == 0 && computeInt == 0 && storage == 0 && + subgroup == 0 && dotProduct == 0 && coopMatrix == 0; + } +}; + +// Maps the given SPIR-V capability to the corresponding device query feature +// and updates features. +// +// Note that the device queries used here should match the ones used in +// iree_hal_vulkan_get_device_properties() on the runtime side. +LogicalResult mapToDeviceQuery(IREE::HAL::ExecutableExportOp entryPoint, + spirv::Capability cap, + KernelFeatures &features) { + switch (cap) { + case spirv::Capability::Shader: + // The shader capability is the root capability for graphics APIs. + // So just ignore. + return success(); + + //===-------------------------------------------------------------------===// + // Compute capabilities + case spirv::Capability::Float16: + features.computeFloat |= 0b01; + return success(); + case spirv::Capability::Float64: + features.computeFloat |= 0b10; + return success(); + case spirv::Capability::Int8: + features.computeInt |= 0b001; + return success(); + case spirv::Capability::Int16: + features.computeInt |= 0b010; + return success(); + case spirv::Capability::Int64: + features.computeInt |= 0b100; + return success(); + + //===-------------------------------------------------------------------===// + // Storage capabilities + case spirv::Capability::UniformAndStorageBuffer8BitAccess: + case spirv::Capability::StorageBuffer8BitAccess: + // These capabilities allow 8-bit types to appear in interface variables of + // a particular storage class. + // So cluster them together. + features.storage |= 0b01; + return success(); + case spirv::Capability::StorageBuffer16BitAccess: + case spirv::Capability::StorageUniform16: + // These capabilities allow 16-bit types to appear in interface variables of + // a particular storage class. + // So cluster them together. + features.storage |= 0b10; + return success(); + + //===-------------------------------------------------------------------===// + // Subgroup capabilities + case spirv::Capability::GroupNonUniform: + // The basic subgroup capability provides access to builtin variables like + // subgroup ID and size. + // * In Vulkan, this is mandated starting v1.1. + // * In Metal, we have it since v2.2. + // So just ignore. + return success(); + case spirv::Capability::GroupNonUniformShuffle: + features.subgroup |= 0b01; + return success(); + case spirv::Capability::GroupNonUniformArithmetic: + features.subgroup |= 0b10; + return success(); + + case spirv::Capability::DotProduct: + case spirv::Capability::DotProductInput4x8Bit: + // We only ever use vector<4xi8> -> i32 variant of dot product right now. + features.dotProduct |= 0b1; + return success(); + + //===-------------------------------------------------------------------===// + // Cooperative matrix capabilities + case spirv::Capability::CooperativeMatrixKHR: { + // Cooperative matrix has many device specific configurations. They are not + // directly reflected in the SPIR-V capabilities. We need to be explicit by + // looking at the chosen configuration. + // Format: "coopmatrix...xx". + auto coopmatType = + entryPoint->getAttrOfType("iree.spirv.coopmatrix.type"); + auto coopmatShape = entryPoint->getAttrOfType( + "iree.spirv.coopmatrix.shape"); + if (!coopmatType || !coopmatShape) + return failure(); + + Type inputType = cast(coopmatType.getValue().front()).getValue(); + Type outputType = cast(coopmatType.getValue().back()).getValue(); + int64_t mSize = coopmatShape.asArrayRef()[0]; + int64_t nSize = coopmatShape.asArrayRef()[1]; + int64_t kSize = coopmatShape.asArrayRef()[2]; + + // We explicitly perform exact match here given that 1) we need to have the + // corresponding query in the runtime, and 2) we are not using a lot of + // configuarations in CodeGen yet. + if (inputType.isF16() && outputType.isF16()) { + if (mSize == 16 && nSize == 16 && kSize == 16) { + features.coopMatrix |= 0b1; + return success(); + } + } + + return success(); + } + + default: + break; + } + return failure(); +} + +// Builds the device query ops using the given builder. +// +// Note that the device queries used here should match the ones used in +// iree_hal_vulkan_device_query_i64() on the runtime side. +void buildDeviceQueryRegion(const KernelFeatures &features, Value device, + Location loc, OpBuilder &builder) { + IntegerType boolType = builder.getI1Type(); + IntegerType i32Type = builder.getI32Type(); + TypedAttr zeroAttr = builder.getZeroAttr(i32Type); + + auto buildQueryOp = [&](const char *key, uint32_t value, Value result) { + auto queryOp = builder.create( + loc, boolType, i32Type, device, builder.getStringAttr("hal.dispatch"), + builder.getStringAttr(key), zeroAttr); + auto zero = builder.create(loc, 0, 32); + auto val = builder.create(loc, value, 32); + auto andOp = builder.create(loc, queryOp.getValue(), val); + auto cmpOp = builder.create(loc, arith::CmpIPredicate::ne, + andOp, zero); + // Verify that 1) the query succeeds and 2) the capability is supported. + auto ok = builder.create(loc, queryOp.getOk(), cmpOp); + return builder.create(loc, result, ok).getResult(); + }; + + Value result = builder.create(loc, true, 1); + if (features.computeFloat) { + result = buildQueryOp("compute.f", features.computeFloat, result); + } + if (features.computeInt) { + result = buildQueryOp("compute.i", features.computeInt, result); + } + if (features.storage) { + result = buildQueryOp("storage", features.storage, result); + } + if (features.subgroup) { + result = buildQueryOp("subgroup", features.subgroup, result); + } + if (features.dotProduct) { + result = buildQueryOp("dotprod", features.dotProduct, result); + } + if (features.coopMatrix) { + result = buildQueryOp("coopmatrix", features.coopMatrix, result); + } + builder.create(loc, result); +} + +// Returns the device queries as a list of unique keys. +SmallVector getDeviceQueries(const KernelFeatures &features) { + SmallVector queries; + if (features.computeFloat) { + queries.push_back("compute.f=" + std::to_string(features.computeFloat)); + } + if (features.computeInt) { + queries.push_back("compute.i=" + std::to_string(features.computeInt)); + } + if (features.storage) { + queries.push_back("storage=" + std::to_string(features.storage)); + } + if (features.subgroup) { + queries.push_back("subgroup=" + std::to_string(features.subgroup)); + } + if (features.dotProduct) { + queries.push_back("dotprod=" + std::to_string(features.dotProduct)); + } + if (features.coopMatrix) { + queries.push_back("coopmatrix=" + std::to_string(features.coopMatrix)); + } + return queries; +} + +struct SPIRVMaterializeExecutableConditionsPass final + : SPIRVMaterializeExecutableConditionsBase< + SPIRVMaterializeExecutableConditionsPass> { + void runOnOperation() override { + IREE::HAL::ExecutableVariantOp variantOp = getOperation(); + if (!usesSPIRVCodeGen(variantOp)) + return; + + IREE::HAL::ExecutableTargetAttr executableTarget = variantOp.getTarget(); + DictionaryAttr configuration = executableTarget.getConfiguration(); + auto spirvTarget = configuration.getAs( + spirv::getTargetEnvAttrName()); + + auto exportOps = variantOp.getOps(); + if (!llvm::hasSingleElement(exportOps)) { + variantOp.emitError("expected to contain exactly one export op"); + return signalPassFailure(); + } + IREE::HAL::ExecutableExportOp exportOp = *exportOps.begin(); + + // Map all required SPIR-V capabilities to device queries and unique them. + // Here we only consider capabilities--version/extension is just the spec + // "container" for them; so we can ignore. + KernelFeatures features; + for (spirv::Capability cap : spirvTarget.getCapabilities()) { + if (failed(mapToDeviceQuery(exportOp, cap, features))) { + variantOp.emitError("failed to handle capability ") + << spirv::stringifyCapability(cap); + return signalPassFailure(); + } + } + + OpBuilder builder(variantOp); + + // Build the hal.executable.condition op inside the variant. + if (!features.empty()) { + Value device = variantOp.createConditionOp(builder); + buildDeviceQueryRegion(features, device, device.getLoc(), builder); + } + + // Build a string list of the used queries too--this is useful for attaching + // to the executable target attribute as a unique key for the linking pass. + SmallVector strings = getDeviceQueries(features); + SmallVector queries; + queries.reserve(strings.size() + 1); + queries.push_back(variantOp.getTarget().getBackend().getValue()); + for (const std::string &s : strings) { + queries.push_back(s); + } + + // Drop the fine-grained SPIR-V target and add the course-grained device + // queries as a list. + auto dictKeyValues = llvm::to_vector(llvm::make_filter_range( + configuration.getValue(), [](NamedAttribute attr) { + return attr.getName() != spirv::getTargetEnvAttrName(); + })); + dictKeyValues.emplace_back(builder.getStringAttr("iree.spirv.features"), + builder.getStrArrayAttr(queries)); + variantOp.setTargetAttr(IREE::HAL::ExecutableTargetAttr::get( + executableTarget.getContext(), executableTarget.getBackend(), + executableTarget.getFormat(), + DictionaryAttr::get(configuration.getContext(), dictKeyValues))); + } +}; + +} // namespace + +std::unique_ptr> +createSPIRVMaterializeExecutableConditionsPass() { + return std::make_unique(); +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp index ba8cc0cb1393..849e90e63c2e 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp @@ -12,20 +12,14 @@ #include "iree/compiler/Codegen/SPIRV/Passes.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" -#include "llvm/Support/Debug.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Transforms/Passes.h" - -#define DEBUG_TYPE "iree-spirv-select-lowering-strategy-pass" namespace mlir::iree_compiler { diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp index a4c966e525a6..a916d2519ae8 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp @@ -27,9 +27,7 @@ #include "iree/compiler/Codegen/Utils/Utils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" -#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Conversion/VectorToGPU/VectorToGPU.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -40,7 +38,6 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" -#include "mlir/Interfaces/VectorInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -68,17 +65,23 @@ static SmallVector getTargetCooperativeOpSize(linalg::LinalgOp op) { return getTileSizes(op, 3); // For native vector sizes } -constexpr char coopMatShapeAttrName[] = "iree.spirv.coop_mat_shape"; +constexpr char coopMatTypeAttrName[] = "iree.spirv.coopmatrix.type"; +constexpr char coopMatShapeAttrName[] = "iree.spirv.coopmatrix.shape"; -/// Sets the chosen cooperative matrix shape for CodeGen onto the +/// Sets the chosen cooperative matrix type/shape for CodeGen onto the /// hal.executable.export op for the given `funcOp`. -void setSPIRVCooperativeMatrixShape(func::FuncOp funcOp, - ArrayRef shape) { +void setSPIRVCooperativeMatrixInfo(func::FuncOp funcOp, linalg::LinalgOp rootOp, + ArrayRef shape) { auto moduleOp = funcOp->getParentOfType(); auto exportOp = getAllEntryPoints(moduleOp).lookup(funcOp.getName()); Builder b(funcOp.getContext()); exportOp->setAttr(coopMatShapeAttrName, b.getDenseI64ArrayAttr(shape)); + auto inputType = cast(rootOp.getDpsInputs().front().getType()); + auto outputType = cast(rootOp.getDpsInits().front().getType()); + auto elementTypes = b.getTypeArrayAttr( + {inputType.getElementType(), outputType.getElementType()}); + exportOp->setAttr(coopMatTypeAttrName, elementTypes); } /// Returns the chosen cooperative matrix shape for CodeGen from the @@ -357,7 +360,7 @@ class SPIRVTileToCooperativeOpsPass final // given that after tiling and vectorization we won't have the root Linalg // op anymore. SmallVector cooperativeOpSize = getTargetCooperativeOpSize(rootOp); - setSPIRVCooperativeMatrixShape(funcOp, cooperativeOpSize); + setSPIRVCooperativeMatrixInfo(funcOp, rootOp, cooperativeOpSize); SmallVector subgroupCounts = deduceSubgroupCounts(rootOp); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTrimExecutableTargetEnv.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTrimExecutableTargetEnv.cpp index e247d9a199fa..9d07ebcd7d66 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTrimExecutableTargetEnv.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTrimExecutableTargetEnv.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Codegen/SPIRV/PassDetail.h" #include "iree/compiler/Codegen/SPIRV/Passes.h" +#include "iree/compiler/Codegen/SPIRV/Utils.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" @@ -16,23 +17,15 @@ namespace mlir::iree_compiler { namespace { -bool IsSPIRVBasedBackend(StringRef backend) { - return backend.starts_with("vulkan") || backend.starts_with("metal") || - backend.starts_with("webgpu"); -} - struct SPIRVTrimExecutableTargetEnvPass final : SPIRVTrimExecutableTargetEnvBase { void runOnOperation() override { IREE::HAL::ExecutableVariantOp variant = getOperation(); - if (!IsSPIRVBasedBackend(variant.getTarget().getBackend())) { - return; - } - if (variant.getObjects().has_value()) { - // Ignore external executable variants. We need to read spirv.module - // ops to get the deduced minimal list of required capability and - // extension. External source executables won't have any spirv.module - // ops inside. + if (!usesSPIRVCodeGen(variant)) { + // Ignore variants not targeting SPIR-V or external executable variants. + // We need to read spirv.module ops to get the deduced minimal list of + // required capability and extension. External source executables won't + // have any spirv.module ops inside. return; } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp index 8fdb97fb37eb..92e15c40a4eb 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp @@ -11,15 +11,28 @@ #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" -#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" -#include "mlir/Support/LogicalResult.h" +#include "mlir/IR/BuiltinAttributes.h" namespace mlir::iree_compiler { +bool usesSPIRVCodeGen(IREE::HAL::ExecutableVariantOp variantOp) { + if (variantOp.getObjects().has_value()) { + // Variants containing external executables do not go through CodeGen. + return false; + } + + DictionaryAttr configuration = variantOp.getTargetAttr().getConfiguration(); + // The spirv.target_env attribute is attached if going down SPIR-V CodeGen + // pipelines. Later we turn spirv.target_env into iree.spirv.features after + // materializing device queries. + return configuration.contains(spirv::getTargetEnvAttrName()) || + configuration.contains("iree.spirv.features"); +} + const char *getSPIRVDistributeAttrName() { return "iree.spirv.distribute_dim"; } spirv::TargetEnvAttr getSPIRVTargetEnvAttr(Operation *op) { diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h index b5ddd0ebe1be..275a882afa66 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h @@ -13,6 +13,7 @@ #ifndef IREE_COMPILER_CODEGEN_SPIRV_UTILS_H_ #define IREE_COMPILER_CODEGEN_SPIRV_UTILS_H_ +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" @@ -21,6 +22,9 @@ namespace mlir::iree_compiler { +// Returns true if the given variant op uses SPIR-V CodeGen. +bool usesSPIRVCodeGen(IREE::HAL::ExecutableVariantOp variantOp); + /// Returns the attribute name carrying information about distribution. const char *getSPIRVDistributeAttrName(); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel index be2d14a8ce2f..d1d63c2ebccd 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel @@ -52,6 +52,7 @@ iree_lit_test_suite( "lowering_scalar_dispatch.mlir", "lowering_reduction.mlir", "map_memref_storage_class.mlir", + "materialize_executable_conditions.mlir", "pipeline_matmul_cooperative_ops.mlir", "pipeline_matmul_promotion.mlir", "pipeline_matmul_vectorization.mlir", diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt index fd9df00cfb35..fdea4975c58b 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt @@ -48,6 +48,7 @@ iree_lit_test_suite( "lowering_reduction.mlir" "lowering_scalar_dispatch.mlir" "map_memref_storage_class.mlir" + "materialize_executable_conditions.mlir" "pipeline_matmul_cooperative_ops.mlir" "pipeline_matmul_promotion.mlir" "pipeline_matmul_vectorization.mlir" diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir index ac1850449299..bd9145eaec7f 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir @@ -6,9 +6,7 @@ // // For such case we can link all executables into one, with just one variant. -#vulkan_target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", { - spirv.target_env = #spirv.target_env<#spirv.vce, - api=Vulkan, Unknown:DiscreteGPU, #spirv.resource_limits<>>}> +#vulkan_target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan"]}> #pipeline_layout = #hal.pipeline.layout i32 as "foo" +// CHECK: hal.executable.constant.block(%arg0: !hal.device) -> i32 as "foo" // CHECK-NEXT: = arith.constant 1 // CHECK: hal.executable.export public @dispatch_0 ordinal(0) // CHECK: hal.return %c1, %c1, %c1 @@ -165,11 +163,9 @@ util.initializer { // having one variant containing all entry points needing the same target. #vulkan_target_0 = #hal.executable.target<"vulkan", "vulkan-spirv-fb", { - spirv.target_env = #spirv.target_env<#spirv.vce, - api=Vulkan, Unknown:DiscreteGPU, #spirv.resource_limits<>>}> + iree.spirv.features = ["vulkan"]}> #vulkan_target_1 = #hal.executable.target<"vulkan", "vulkan-spirv-fb", { - spirv.target_env = #spirv.target_env<#spirv.vce, - api=Vulkan, Unknown:DiscreteGPU, #spirv.resource_limits<>>}> + iree.spirv.features = ["vulkan", "subgroup=1"]}> #pipeline_layout = #hal.pipeline.layout i1 { + %ok, %value = hal.device.query<%arg0 : !hal.device> key("hal.device.vulkan" :: "subgroup") : i1, i32 = 0 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = arith.andi %value, %c1_i32 : i32 + %1 = arith.cmpi ne, %0, %c0_i32 : i32 + %2 = arith.andi %ok, %1 : i1 + hal.return %2 : i1 + } hal.executable.constant.block(%device: !hal.device) -> i32 as "baz" { %c2 = arith.constant 2 : i32 hal.return %c2 : i32 @@ -236,6 +241,15 @@ hal.executable private @dispatch_2 { } hal.executable private @dispatch_3 { hal.executable.variant @spirv target(#vulkan_target_1) { + hal.executable.condition(%arg0: !hal.device) -> i1 { + %ok, %value = hal.device.query<%arg0 : !hal.device> key("hal.device.vulkan" :: "subgroup") : i1, i32 = 0 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = arith.andi %value, %c1_i32 : i32 + %1 = arith.cmpi ne, %0, %c0_i32 : i32 + %2 = arith.andi %ok, %1 : i1 + hal.return %2 : i1 + } hal.executable.export @dispatch_3 ordinal(0) layout(#pipeline_layout) { ^bb0(%arg0: !hal.device) : %c1 = arith.constant 1 : index @@ -262,10 +276,8 @@ func.func @two_target_environments() -> () { return } -// CHECK: #[[TARGET0:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", -// CHECK-SAME: #spirv.target_env<#spirv.vce -// CHECK: #[[TARGET1:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", -// CHECK-SAME: #spirv.target_env<#spirv.vce +// CHECK-DAG: #[[TARGET0:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan"]} +// CHECK-DAG: #[[TARGET1:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "subgroup=1"]} // CHECK: hal.executable private @link_executables_linked_spirv_0 { // CHECK: hal.executable.variant public @vulkan_spirv_fb target(#[[TARGET0]]) { @@ -289,6 +301,8 @@ func.func @two_target_environments() -> () { // CHECK: } // CHECK: hal.executable private @link_executables_linked_spirv_1 { // CHECK: hal.executable.variant public @vulkan_spirv_fb target(#[[TARGET1]]) { +// CHECK: hal.executable.condition(%arg0: !hal.device) -> i1 +// CHECK-NEXT: hal.device.query<%arg0 : !hal.device> key("hal.device.vulkan" :: "subgroup") // CHECK: hal.executable.constant.block(%arg0: !hal.device) -> i32 as "baz" // CHECK-NEXT: = arith.constant 2 : i32 // CHECK: hal.executable.export public @dispatch_1 ordinal(0) @@ -324,14 +338,11 @@ func.func @two_target_environments() -> () { // same set of target requirements. #vulkan_target_0 = #hal.executable.target<"vulkan", "vulkan-spirv-fb", { - spirv.target_env = #spirv.target_env<#spirv.vce, - api=Vulkan, Unknown:DiscreteGPU, #spirv.resource_limits<>>}> + iree.spirv.features = ["vulkan"]}> #vulkan_target_1 = #hal.executable.target<"vulkan", "vulkan-spirv-fb", { - spirv.target_env = #spirv.target_env<#spirv.vce, - api=Vulkan, Unknown:DiscreteGPU, #spirv.resource_limits<>>}> + iree.spirv.features = ["vulkan", "subgroup=1"]}> #vulkan_target_2 = #hal.executable.target<"vulkan", "vulkan-spirv-fb", { - spirv.target_env = #spirv.target_env<#spirv.vce, - api=Vulkan, Unknown:DiscreteGPU, #spirv.resource_limits<>>}> + iree.spirv.features = ["vulkan", "subgroup=2"]}> #pipeline_layout = #hal.pipeline.layout i1 { + %ok, %value = hal.device.query<%arg0 : !hal.device> key("hal.device.vulkan" :: "subgroup") : i1, i32 = 0 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = arith.andi %value, %c1_i32 : i32 + %1 = arith.cmpi ne, %0, %c0_i32 : i32 + %2 = arith.andi %ok, %1 : i1 + hal.return %2 : i1 + } hal.executable.constant.block(%device: !hal.device) -> i32 as "foo" { %c2 = arith.constant 2 : i32 hal.return %c2 : i32 @@ -400,6 +420,15 @@ hal.executable private @dispatch_1 { } } hal.executable.variant @spirv_1 target(#vulkan_target_1) { + hal.executable.condition(%arg0: !hal.device) -> i1 { + %ok, %value = hal.device.query<%arg0 : !hal.device> key("hal.device.vulkan" :: "subgroup") : i1, i32 = 0 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = arith.andi %value, %c1_i32 : i32 + %1 = arith.cmpi ne, %0, %c0_i32 : i32 + %2 = arith.andi %ok, %1 : i1 + hal.return %2 : i1 + } hal.executable.constant.block(%device: !hal.device) -> i32 as "baz" { %c4 = arith.constant 4 : i32 hal.return %c4 : i32 @@ -454,6 +483,15 @@ hal.executable private @dispatch_3 { } } hal.executable.variant @spirv_1 target(#vulkan_target_2) { + hal.executable.condition(%arg0: !hal.device) -> i1 { + %ok, %value = hal.device.query<%arg0 : !hal.device> key("hal.device.vulkan" :: "subgroup") : i1, i32 = 0 : i32 + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %0 = arith.andi %value, %c2_i32 : i32 + %1 = arith.cmpi ne, %0, %c0_i32 : i32 + %2 = arith.andi %ok, %1 : i1 + hal.return %2 : i1 + } hal.executable.export @dispatch_3 ordinal(0) layout(#pipeline_layout) { ^bb0(%arg0: !hal.device) : %c1 = arith.constant 1 : index @@ -469,12 +507,9 @@ hal.executable private @dispatch_3 { } } -// CHECK: #[[TARGET0:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", -// CHECK-SAME: #spirv.target_env<#spirv.vce -// CHECK: #[[TARGET1:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", -// CHECK-SAME: #spirv.target_env<#spirv.vce -// CHECK: #[[TARGET2:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", -// CHECK-SAME: #spirv.target_env<#spirv.vce +// CHECK-DAG: #[[TARGET0:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan"]} +// CHECK-DAG: #[[TARGET1:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "subgroup=1"]} +// CHECK-DAG: #[[TARGET2:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "subgroup=2"]} // CHECK: hal.executable private @link_executables_linked_spirv { // CHECK: hal.executable.variant public @vulkan_spirv_fb_0 target(#[[TARGET0]]) { @@ -500,6 +535,10 @@ hal.executable private @dispatch_3 { // CHECK: } // CHECK: } // CHECK: hal.executable.variant public @vulkan_spirv_fb_1 target(#[[TARGET1]]) { +// CHECK: hal.executable.condition(%arg0: !hal.device) -> i1 +// CHECK-NEXT: %{{.+}}, %[[V:.+]] = hal.device.query<%arg0 : !hal.device> key("hal.device.vulkan" :: "subgroup") +// CHECK: %[[TARGET:.+]] = arith.constant 1 : i32 +// CHECK-NEXT: %{{.+}} = arith.andi %[[V]], %[[TARGET]] : i32 // CHECK: hal.executable.constant.block(%arg0: !hal.device) -> i32 as "foo" // CHECK-NEXT: = arith.constant 2 : i32 // CHECK: hal.executable.export public @dispatch_0 ordinal(0) diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/materialize_executable_conditions.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/materialize_executable_conditions.mlir new file mode 100644 index 000000000000..c922dd5724f5 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/materialize_executable_conditions.mlir @@ -0,0 +1,236 @@ +// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-spirv-materialize-executable-conditions)))' --mlir-print-local-scope %s | FileCheck %s + +#pipeline_layout = #hal.pipeline.layout, + <1, storage_buffer, ReadOnly>, + <2, storage_buffer> + ]> +]> + +hal.executable private @dispatch_executable { + // CHECK-LABEL: hal.executable.variant public @test_assumed_capabilities + // CHECK-SAME: target(<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan"]}>) + // CHECK-NOT: hal.executable.condition + hal.executable.variant public @test_assumed_capabilities target( + #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + }> + ) { + hal.executable.export public @test_assumed_capabilities ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @test_assumed_capabilities() "None" { spirv.Return } + spirv.EntryPoint "GLCompute" @test_assumed_capabilities + spirv.ExecutionMode @test_assumed_capabilities "LocalSize", 64, 1, 1 + } + } + } + + // CHECK-LABEL: hal.executable.variant public @test_subgroup_capabilities + // CHECK-SAME: target(<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "subgroup=3"]}>) + // CHECK-NEXT: hal.executable.condition(%[[DEV:.+]]: !hal.device) -> i1 { + // CHECK-NEXT: %[[T:.+]] = arith.constant true + // CHECK-NEXT: %[[OK:.+]], %[[V:.+]] = hal.device.query<%[[DEV]] : !hal.device> + // CHECK-SAME: key("hal.dispatch" :: "subgroup") : i1, i32 = 0 : i32 + // CHECK-NEXT: %[[ZERO:.+]] = arith.constant 0 : i32 + // CHECK-NEXT: %[[TARGET:.+]] = arith.constant 3 : i32 + // CHECK-NEXT: %[[CHECK:.+]] = arith.andi %[[V]], %[[TARGET]] : i32 + // CHECK-NEXT: %[[CMP:.+]] = arith.cmpi ne, %[[CHECK]], %[[ZERO]] : i32 + // CHECK-NEXT: %[[AND:.+]] = arith.andi %[[OK]], %[[CMP]] : i1 + // CHECK-NEXT: %[[RESULT:.+]] = arith.andi %[[T]], %[[AND]] : i1 + // CHECK-NEXT: hal.return %[[RESULT]] : i1 + // CHECK-NEXT: } + hal.executable.variant public @test_subgroup_capabilities target( + #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + }> + ) { + hal.executable.export public @test_subgroup_capabilities ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @test_subgroup_capabilities() "None" { spirv.Return } + spirv.EntryPoint "GLCompute" @test_subgroup_capabilities + spirv.ExecutionMode @test_subgroup_capabilities "LocalSize", 64, 1, 1 + } + } + } + + // CHECK-LABEL: hal.executable.variant public @test_8bit_storage_capabilities + // CHECK-SAME: target(<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "storage=1"]}>) + // CHECK-NEXT: hal.executable.condition(%[[DEV:.+]]: !hal.device) -> i1 { + // CHECK-NEXT: %[[T:.+]] = arith.constant true + // CHECK-NEXT: %[[OK:.+]], %[[V:.+]] = hal.device.query<%[[DEV]] : !hal.device> + // CHECK-SAME: key("hal.dispatch" :: "storage") : i1, i32 = 0 : i32 + // CHECK-NEXT: %[[ZERO:.+]] = arith.constant 0 : i32 + // CHECK-NEXT: %[[TARGET:.+]] = arith.constant 1 : i32 + // CHECK-NEXT: %[[CHECK:.+]] = arith.andi %[[V]], %[[TARGET]] : i32 + // CHECK-NEXT: %[[CMP:.+]] = arith.cmpi ne, %[[CHECK]], %[[ZERO]] : i32 + // CHECK-NEXT: %[[AND:.+]] = arith.andi %[[OK]], %[[CMP]] : i1 + // CHECK-NEXT: %[[RESULT:.+]] = arith.andi %[[T]], %[[AND]] : i1 + // CHECK-NEXT: hal.return %[[RESULT]] : i1 + // CHECK-NEXT: } + hal.executable.variant public @test_8bit_storage_capabilities target( + #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + }> + ) { + hal.executable.export public @test_8bit_storage_capabilities ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + spirv.module Logical GLSL450 requires + #spirv.vce { + spirv.func @test_8bit_storage_capabilities() "None" { spirv.Return } + spirv.EntryPoint "GLCompute" @test_8bit_storage_capabilities + spirv.ExecutionMode @test_8bit_storage_capabilities "LocalSize", 64, 1, 1 + } + } + } + + // CHECK-LABEL: hal.executable.variant public @test_16bit_storage_capabilities + // CHECK-SAME: target(<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "storage=2"]}>) + // CHECK-NEXT: hal.executable.condition(%[[DEV:.+]]: !hal.device) -> i1 { + // CHECK-NEXT: %[[T:.+]] = arith.constant true + // CHECK-NEXT: %[[OK:.+]], %[[V:.+]] = hal.device.query<%[[DEV]] : !hal.device> + // CHECK-SAME: key("hal.dispatch" :: "storage") : i1, i32 = 0 : i32 + // CHECK-NEXT: %[[ZERO:.+]] = arith.constant 0 : i32 + // CHECK-NEXT: %[[TARGET:.+]] = arith.constant 2 : i32 + // CHECK-NEXT: %[[CHECK:.+]] = arith.andi %[[V]], %[[TARGET]] : i32 + // CHECK-NEXT: %[[CMP:.+]] = arith.cmpi ne, %[[CHECK]], %[[ZERO]] : i32 + // CHECK-NEXT: %[[AND:.+]] = arith.andi %[[OK]], %[[CMP]] : i1 + // CHECK-NEXT: %[[RESULT:.+]] = arith.andi %[[T]], %[[AND]] : i1 + // CHECK-NEXT: hal.return %[[RESULT]] : i1 + // CHECK-NEXT: } + hal.executable.variant public @test_16bit_storage_capabilities target( + #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + }> + ) { + hal.executable.export public @test_16bit_storage_capabilities ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + spirv.module Logical GLSL450 requires + #spirv.vce { + spirv.func @test_16bit_storage_capabilities() "None" { spirv.Return } + spirv.EntryPoint "GLCompute" @test_16bit_storage_capabilities + spirv.ExecutionMode @test_16bit_storage_capabilities "LocalSize", 64, 1, 1 + } + } + } + + // CHECK-LABEL: hal.executable.variant public @test_int_compute_capabilities + // CHECK-SAME: target(<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "compute.i=7"]}>) + // CHECK: %{{.+}}, %[[V:.+]] = hal.device.query<%{{.+}} : !hal.device> + // CHECK-SAME: key("hal.dispatch" :: "compute.i") : i1, i32 = 0 : i32 + // CHECK: %[[TARGET:.+]] = arith.constant 7 : i32 + // CHECK: %{{.+}} = arith.andi %[[V]], %[[TARGET]] : i32 + hal.executable.variant public @test_int_compute_capabilities target( + #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + }> + ) { + hal.executable.export public @test_int_compute_capabilities ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @test_int_compute_capabilities() "None" { spirv.Return } + spirv.EntryPoint "GLCompute" @test_int_compute_capabilities + spirv.ExecutionMode @test_int_compute_capabilities "LocalSize", 64, 1, 1 + } + } + } + + // CHECK-LABEL: hal.executable.variant public @test_float_compute_capabilities + // CHECK-SAME: target(<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "compute.f=3"]}>) + // CHECK: %{{.+}}, %[[V:.+]] = hal.device.query<%{{.+}} : !hal.device> + // CHECK-SAME: key("hal.dispatch" :: "compute.f") : i1, i32 = 0 : i32 + // CHECK: %[[TARGET:.+]] = arith.constant 3 : i32 + // CHECK: %{{.+}} = arith.andi %[[V]], %[[TARGET]] : i32 + hal.executable.variant public @test_float_compute_capabilities target( + #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + }> + ) { + hal.executable.export public @test_float_compute_capabilities ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @test_float_compute_capabilities() "None" { spirv.Return } + spirv.EntryPoint "GLCompute" @test_float_compute_capabilities + spirv.ExecutionMode @test_float_compute_capabilities "LocalSize", 64, 1, 1 + } + } + } + + // CHECK-LABEL: hal.executable.variant public @test_dot_product_capabilities + // CHECK-SAME: target(<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "dotprod=1"]}>) + // CHECK: %{{.+}}, %[[V:.+]] = hal.device.query<%{{.+}} : !hal.device> + // CHECK-SAME: key("hal.dispatch" :: "dotprod") : i1, i32 = 0 : i32 + // CHECK: %[[TARGET:.+]] = arith.constant 1 : i32 + // CHECK: %{{.+}} = arith.andi %[[V]], %[[TARGET]] : i32 + hal.executable.variant public @test_dot_product_capabilities target( + #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + }> + ) { + hal.executable.export public @test_dot_product_capabilities ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @test_dot_product_capabilities() "None" { spirv.Return } + spirv.EntryPoint "GLCompute" @test_dot_product_capabilities + spirv.ExecutionMode @test_dot_product_capabilities "LocalSize", 64, 1, 1 + } + } + } + + // CHECK-LABEL: hal.executable.variant public @test_cooperative_matrix_capabilities + // CHECK-SAME: target(<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "coopmatrix=1"]}>) + // CHECK: %{{.+}}, %[[V:.+]] = hal.device.query<%{{.+}} : !hal.device> + // CHECK-SAME: key("hal.dispatch" :: "coopmatrix") : i1, i32 = 0 : i32 + // CHECK: %[[TARGET:.+]] = arith.constant 1 : i32 + // CHECK: %{{.+}} = arith.andi %[[V]], %[[TARGET]] : i32 + hal.executable.variant public @test_cooperative_matrix_capabilities target( + #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + }> + ) { + hal.executable.export public @test_cooperative_matrix_capabilities ordinal(0) layout(#pipeline_layout) attributes { + iree.spirv.coopmatrix.shape = array, iree.spirv.coopmatrix.type = [f16, f16] + } { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @test_cooperative_matrix_capabilities() "None" { spirv.Return } + spirv.EntryPoint "GLCompute" @test_cooperative_matrix_capabilities + spirv.ExecutionMode @test_cooperative_matrix_capabilities "LocalSize", 64, 1, 1 + } + } + } +} diff --git a/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel index dd5dc93c46cd..823a0956b917 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel @@ -38,6 +38,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Interfaces:UKernelOpInterface", "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/HAL/IR", + "//compiler/src/iree/compiler/Utils", "//llvm-external-projects/iree-dialects:IREELinalgExtDialect", "//llvm-external-projects/iree-dialects:IREELinalgExtPasses", "@llvm-project//llvm:Support", diff --git a/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt index d73ef76dba30..1e7fa2fc7f06 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt @@ -57,6 +57,7 @@ iree_cc_library( iree::compiler::Codegen::Interfaces::UKernelOpInterface iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::HAL::IR + iree::compiler::Utils PUBLIC ) diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp index c2df59e9fccf..9fd7492a9c4f 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp @@ -7,6 +7,8 @@ #include "iree/compiler/Codegen/Utils/LinkingUtils.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Utils/EquivalenceUtils.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/TypeUtilities.h" @@ -235,6 +237,29 @@ LogicalResult linkExecutablesInto( {SymbolRefAttr::get(linkedTargetOp)}); symbolReplacements.variantRefs[oldVariantRefAttr] = newVariantRefAttr; + // Move the condition op too. We need to make sure all variant's condition + // op has the same content. + auto targetConditionOps = + linkedTargetOp.getOps(); + if (auto sourceCoditionOp = variantOp.getConditionOp()) { + if (targetConditionOps.empty()) { + sourceCoditionOp->moveBefore( + &*linkedTargetBuilder.getInsertionPoint()); + } else { + assert(llvm::hasSingleElement(targetConditionOps)); + IREE::HAL::ExecutableConditionOp referenceOp = + *targetConditionOps.begin(); + if (!isStructurallyEquivalentTo(*sourceCoditionOp.getOperation(), + *referenceOp.getOperation())) { + return variantOp.emitError("contains incompatible condition op"); + } + } + } else { + if (!targetConditionOps.empty()) { + return variantOp.emitError("should contain a condition op"); + } + } + // Move any constant blocks that need to be preserved for future host // translation. There may be duplicates provided but they'll be cleaned // up in future passes. diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index 0eb86d2baf09..25d26c1e6295 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -1220,6 +1220,18 @@ DenseMap ExecutableVariantOp::gatherConstantOrdinals() { return map; } +Value ExecutableVariantOp::createConditionOp(OpBuilder &builder) { + assert(!getConditionOp() && "condition op already exists"); + + builder.setInsertionPointToStart(&getRegion().front()); + auto conditionOp = builder.create(getLoc()); + Block *entryPoint = conditionOp.addEntryBlock(); + Value device = entryPoint->getArgument(0); + + builder.setInsertionPointToStart(entryPoint); + return device; +} + Value ExecutableVariantOp::buildCondition(Value device, OpBuilder &builder) { // Base case dependent on target information. // TODO(multi-device): condition on device target ID and other queries that diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index 91bda119787e..164ffa28ea2c 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -2162,6 +2162,11 @@ def HAL_ExecutableVariantOp : HAL_Op<"executable.variant", [ // blocks inside the variant. DenseMap gatherConstantOrdinals(); + // Creates the new `hal.executable.condition` op in this variant op and sets + // the insertion point of the provided builder to the beginning of the new + // region. + Value createConditionOp(OpBuilder &builder); + // Returns an i1 indicating whether this variant should be selected. Value buildCondition(Value device, OpBuilder &builder); }]; diff --git a/runtime/src/iree/hal/drivers/vulkan/dynamic_symbol_tables.h b/runtime/src/iree/hal/drivers/vulkan/dynamic_symbol_tables.h index 8fb4403c28e9..6497a44c2aa7 100644 --- a/runtime/src/iree/hal/drivers/vulkan/dynamic_symbol_tables.h +++ b/runtime/src/iree/hal/drivers/vulkan/dynamic_symbol_tables.h @@ -319,7 +319,7 @@ namespace vulkan { INS_PFN(EXCLUDED, vkGetDisplayPlaneCapabilitiesKHR) \ INS_PFN(EXCLUDED, vkGetDisplayPlaneSupportedDisplaysKHR) \ INS_PFN(OPTIONAL, vkGetPhysicalDeviceCalibrateableTimeDomainsEXT) \ - INS_PFN(EXCLUDED, vkGetPhysicalDeviceCooperativeMatrixPropertiesNV) \ + INS_PFN(OPTIONAL, vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR) \ INS_PFN(EXCLUDED, vkGetPhysicalDeviceDisplayPlaneProperties2KHR) \ INS_PFN(EXCLUDED, vkGetPhysicalDeviceDisplayPlanePropertiesKHR) \ INS_PFN(EXCLUDED, vkGetPhysicalDeviceDisplayProperties2KHR) \ diff --git a/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h b/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h index 562b6513c498..4c12179c1894 100644 --- a/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h +++ b/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h @@ -106,4 +106,40 @@ iree_hal_vulkan_device_extensions_t iree_hal_vulkan_infer_enabled_device_extensions( const iree::hal::vulkan::DynamicSymbols* device_syms); +// Struct for supported device properties. +// +// Note that the fields used here should match the ones used in KernelFeatures +// on the compiler side. +typedef struct iree_hal_vulkan_device_properties_t { + // Floating-point compute related feature bitfield: + // * 0b01: f16 + // * 0b10: f64 + // Note that f32 is assumed to always exist and does not appear in this + // bitfield. + uint32_t compute_float : 8; + // Integer compute related feature bitfield: + // * 0b001: i8 + // * 0b010: i16 + // * 0b100: i64 + // Note that i32 or i1 is assumed to always exist and does not appear in + // this bitfield. + uint32_t compute_int : 8; + // Storage bitwidth requirement bitfiled: + // * 0b01: 8-bit + // * 0b10: 16-bit + uint32_t storage : 8; + // Subgroup operation requirement bitfield: + // * 0b01: subgroup shuffle operations + // * 0b10: subgroup arithmetic operations + uint32_t subgroup : 8; + // Dot product operation requirement bitfield: + // ("dotprod..") + // * 0b01: dotprod.4xi8.i32 + uint32_t dot_product : 8; + // Cooperative matrix requirement bitfield: + // ("coopmatrix...xx") + // * 0b01: coopmatrix.f16.f16.16x16x16 + uint32_t cooperative_matrix : 8; +} iree_hal_vulkan_iree_hal_vulkan_device_properties_t; + #endif // IREE_HAL_DRIVERS_VULKAN_EXTENSIBILITY_UTIL_H_ diff --git a/runtime/src/iree/hal/drivers/vulkan/handle_util.h b/runtime/src/iree/hal/drivers/vulkan/handle_util.h index 9cf0349e04ee..c85bcb624ec4 100644 --- a/runtime/src/iree/hal/drivers/vulkan/handle_util.h +++ b/runtime/src/iree/hal/drivers/vulkan/handle_util.h @@ -42,12 +42,14 @@ class VkDeviceHandle : public RefObject { VkDeviceHandle(DynamicSymbols* syms, VkPhysicalDevice physical_device, iree_hal_vulkan_features_t enabled_features, iree_hal_vulkan_device_extensions_t enabled_extensions, + iree_hal_vulkan_device_properties_t supported_properties, bool owns_device, iree_allocator_t host_allocator, const VkAllocationCallbacks* allocator = nullptr) : physical_device_(physical_device), syms_(add_ref(syms)), enabled_features_(enabled_features), enabled_extensions_(enabled_extensions), + supported_properties_(supported_properties), owns_device_(owns_device), allocator_(allocator), host_allocator_(host_allocator) {} @@ -62,6 +64,7 @@ class VkDeviceHandle : public RefObject { value_(exchange(other.value_, static_cast(VK_NULL_HANDLE))), syms_(std::move(other.syms_)), enabled_extensions_(other.enabled_extensions_), + supported_properties_(other.supported_properties_), owns_device_(other.owns_device_), allocator_(other.allocator_), host_allocator_(other.host_allocator_) {} @@ -93,12 +96,17 @@ class VkDeviceHandle : public RefObject { return enabled_extensions_; } + const iree_hal_vulkan_device_properties_t& supported_properties() const { + return supported_properties_; + } + private: VkPhysicalDevice physical_device_ = VK_NULL_HANDLE; VkDevice value_ = VK_NULL_HANDLE; ref_ptr syms_; iree_hal_vulkan_features_t enabled_features_; iree_hal_vulkan_device_extensions_t enabled_extensions_; + iree_hal_vulkan_device_properties_t supported_properties_; bool owns_device_; const VkAllocationCallbacks* allocator_ = nullptr; iree_allocator_t host_allocator_; diff --git a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc index faffd5a1dc30..41c2aac07b37 100644 --- a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc +++ b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc @@ -498,6 +498,8 @@ typedef struct iree_hal_vulkan_device_t { iree_hal_vulkan_device_flags_t flags; // Which optional extensions are active and available on the device. iree_hal_vulkan_device_extensions_t device_extensions; + // Device properties for various optional features. + iree_hal_vulkan_device_properties_t device_properties; VkInstance instance; VkPhysicalDevice physical_device; @@ -690,6 +692,7 @@ static iree_status_t iree_hal_vulkan_device_create_internal( const iree_hal_vulkan_device_options_t* options, VkInstance instance, VkPhysicalDevice physical_device, VkDeviceHandle* logical_device, const iree_hal_vulkan_device_extensions_t* device_extensions, + const iree_hal_vulkan_device_properties_t* device_properties, const iree_hal_vulkan_queue_set_t* compute_queue_set, const iree_hal_vulkan_queue_set_t* transfer_queue_set, iree_allocator_t host_allocator, iree_hal_device_t** out_device) { @@ -721,6 +724,7 @@ static iree_status_t iree_hal_vulkan_device_create_internal( device->flags = options->flags; device->device_extensions = *device_extensions; + device->device_properties = *device_properties; device->instance = instance; device->physical_device = physical_device; device->logical_device = logical_device; @@ -846,6 +850,160 @@ static iree_status_t iree_hal_vulkan_device_query_extensibility_set( return iree_ok_status(); } +static iree_status_t iree_hal_vulkan_get_device_properties( + DynamicSymbols* instance_syms, VkPhysicalDevice physical_device, + iree_hal_vulkan_device_properties_t* device_properties) { + memset(device_properties, 0, sizeof(*device_properties)); + + VkPhysicalDeviceFeatures2 physical_device_features; + memset(&physical_device_features, 0, sizeof(physical_device_features)); + physical_device_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + + // + Shader float16 and int8 features. + VkPhysicalDeviceShaderFloat16Int8Features shader_float16_int8_features; + memset(&shader_float16_int8_features, 0, + sizeof(shader_float16_int8_features)); + shader_float16_int8_features.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES; + shader_float16_int8_features.pNext = physical_device_features.pNext; + physical_device_features.pNext = &shader_float16_int8_features; + + // + Shader 8 bit storage features. + VkPhysicalDevice8BitStorageFeatures supported_8bit_storage_features; + memset(&supported_8bit_storage_features, 0, + sizeof(supported_8bit_storage_features)); + supported_8bit_storage_features.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES; + supported_8bit_storage_features.pNext = physical_device_features.pNext; + physical_device_features.pNext = &supported_8bit_storage_features; + + // + Shader 16 bit storage features. + VkPhysicalDevice16BitStorageFeatures supported_16bit_storage_features; + memset(&supported_16bit_storage_features, 0, + sizeof(supported_16bit_storage_features)); + supported_16bit_storage_features.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES; + supported_16bit_storage_features.pNext = physical_device_features.pNext; + physical_device_features.pNext = &supported_16bit_storage_features; + + // + Shader integer dot product features. + VkPhysicalDeviceShaderIntegerDotProductFeatures dot_product_features; + memset(&dot_product_features, 0, sizeof(dot_product_features)); + dot_product_features.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES; + dot_product_features.pNext = physical_device_features.pNext; + physical_device_features.pNext = &dot_product_features; + + // + Cooperative matrix features. + VkPhysicalDeviceCooperativeMatrixFeaturesKHR coop_matrix_features; + memset(&coop_matrix_features, 0, sizeof(coop_matrix_features)); + coop_matrix_features.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; + coop_matrix_features.pNext = physical_device_features.pNext; + physical_device_features.pNext = &coop_matrix_features; + + instance_syms->vkGetPhysicalDeviceFeatures2(physical_device, + &physical_device_features); + + VkPhysicalDeviceProperties2 physical_device_properties; + memset(&physical_device_properties, 0, sizeof(physical_device_properties)); + physical_device_properties.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2; + physical_device_properties.pNext = NULL; + + // + Subgroup properties. + VkPhysicalDeviceSubgroupProperties subgroup_properties; + memset(&subgroup_properties, 0, sizeof(subgroup_properties)); + subgroup_properties.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES; + subgroup_properties.pNext = physical_device_properties.pNext; + physical_device_properties.pNext = &subgroup_properties; + + // + Shader integer dot product properties. + VkPhysicalDeviceShaderIntegerDotProductProperties dot_product_properties; + memset(&dot_product_properties, 0, sizeof(dot_product_properties)); + dot_product_properties.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_PROPERTIES; + dot_product_properties.pNext = physical_device_properties.pNext; + physical_device_properties.pNext = &dot_product_properties; + + instance_syms->vkGetPhysicalDeviceProperties2(physical_device, + &physical_device_properties); + + if (shader_float16_int8_features.shaderFloat16) { + device_properties->compute_float |= 0x1u; + } + if (physical_device_features.features.shaderFloat64) { + device_properties->compute_float |= 0x2u; + } + if (shader_float16_int8_features.shaderInt8) { + device_properties->compute_int |= 0x1u; + } + if (physical_device_features.features.shaderInt16) { + device_properties->compute_int |= 0x2u; + } + if (physical_device_features.features.shaderInt64) { + device_properties->compute_int |= 0x4u; + } + if (supported_8bit_storage_features.storageBuffer8BitAccess && + supported_8bit_storage_features.uniformAndStorageBuffer8BitAccess) { + device_properties->storage |= 0x1u; + } + if (supported_16bit_storage_features.storageBuffer16BitAccess && + supported_16bit_storage_features.uniformAndStorageBuffer16BitAccess) { + device_properties->storage |= 0x2u; + } + + if (iree_all_bits_set(subgroup_properties.supportedOperations, + VK_SUBGROUP_FEATURE_SHUFFLE_BIT)) { + device_properties->subgroup |= 0x1u; + } + if (iree_all_bits_set(subgroup_properties.supportedOperations, + VK_SUBGROUP_FEATURE_ARITHMETIC_BIT)) { + device_properties->subgroup |= 0x2u; + } + + if (dot_product_features.shaderIntegerDotProduct && + dot_product_properties.integerDotProduct8BitUnsignedAccelerated && + dot_product_properties.integerDotProduct8BitSignedAccelerated && + dot_product_properties.integerDotProduct8BitMixedSignednessAccelerated && + dot_product_properties + .integerDotProductAccumulatingSaturating8BitUnsignedAccelerated && + dot_product_properties + .integerDotProductAccumulatingSaturating8BitSignedAccelerated && + dot_product_properties + .integerDotProductAccumulatingSaturating8BitMixedSignednessAccelerated) { + device_properties->dot_product |= 0x1u; + } + + if (coop_matrix_features.cooperativeMatrix && + instance_syms->vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR) { + uint32_t count = 0; + IREE_RETURN_IF_ERROR(VK_RESULT_TO_STATUS( + instance_syms->vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR( + physical_device, &count, NULL))); + VkCooperativeMatrixPropertiesKHR* properties = + (VkCooperativeMatrixPropertiesKHR*)iree_alloca( + count * sizeof(VkCooperativeMatrixPropertiesKHR)); + IREE_RETURN_IF_ERROR(VK_RESULT_TO_STATUS( + instance_syms->vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR( + physical_device, &count, properties))); + for (uint32_t i = 0; i < count; ++i) { + VkCooperativeMatrixPropertiesKHR* p = properties + i; + if (p->AType == VK_COMPONENT_TYPE_FLOAT16_KHR && + p->BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + if (p->CType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + if (p->MSize == 16 && p->NSize == 16 && p->KSize == 16) { + device_properties->cooperative_matrix |= 0x1u; + } + } + } + } + } + + return iree_ok_status(); +} + iree_status_t iree_hal_vulkan_device_create( iree_hal_driver_t* driver, iree_string_view_t identifier, iree_hal_vulkan_features_t requested_features, @@ -966,6 +1124,15 @@ iree_status_t iree_hal_vulkan_device_create( available_shader_float16_int8_features.pNext = available_features2.pNext; available_features2.pNext = &available_shader_float16_int8_features; + // + Subgroup matrix features. + VkPhysicalDeviceSubgroupProperties available_subgroup_properties; + memset(&available_subgroup_properties, 0, + sizeof(available_subgroup_properties)); + available_subgroup_properties.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES; + available_subgroup_properties.pNext = available_features2.pNext; + available_features2.pNext = &available_subgroup_properties; + // + Cooperative matrix features. VkPhysicalDeviceCooperativeMatrixFeaturesKHR available_coop_matrix_features; memset(&available_coop_matrix_features, 0, @@ -1094,15 +1261,19 @@ iree_status_t iree_hal_vulkan_device_create( enabled_features2.pNext = &available_shader_float16_int8_features; } - // Enable all available coop matrix features. + // Enable all available cooperative matrix features. if (enabled_device_extensions.cooperative_matrix) { available_coop_matrix_features.pNext = enabled_features2.pNext; enabled_features2.pNext = &available_coop_matrix_features; } + iree_hal_vulkan_device_properties_t device_properties; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_get_device_properties( + instance_syms, physical_device, &device_properties)); + auto logical_device = new VkDeviceHandle( instance_syms, physical_device, enabled_features, - enabled_device_extensions, + enabled_device_extensions, device_properties, /*owns_device=*/true, host_allocator, /*allocator=*/NULL); iree_status_t status = VK_RESULT_TO_STATUS( @@ -1129,7 +1300,8 @@ iree_status_t iree_hal_vulkan_device_create( status = iree_hal_vulkan_device_create_internal( driver, identifier, enabled_features, options, instance, physical_device, logical_device, &enabled_device_extensions, - &compute_queue_set, &transfer_queue_set, host_allocator, out_device); + &device_properties, &compute_queue_set, &transfer_queue_set, + host_allocator, out_device); } logical_device->ReleaseReference(); @@ -1169,6 +1341,11 @@ IREE_API_EXPORT iree_status_t iree_hal_vulkan_wrap_device( iree_hal_vulkan_device_extensions_t enabled_device_extensions = iree_hal_vulkan_infer_enabled_device_extensions(device_syms.get()); + // We can still retrieve the correct device properties though. + iree_hal_vulkan_device_properties_t device_properties; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_get_device_properties( + device_syms.get(), physical_device, &device_properties)); + iree_hal_vulkan_features_t enabled_features = 0; #if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE enabled_features |= IREE_HAL_VULKAN_FEATURE_ENABLE_TRACING; @@ -1177,7 +1354,7 @@ IREE_API_EXPORT iree_status_t iree_hal_vulkan_wrap_device( // Wrap the provided VkDevice with a VkDeviceHandle for use within the HAL. auto logical_device_handle = new VkDeviceHandle( device_syms.get(), physical_device, enabled_features, - enabled_device_extensions, + enabled_device_extensions, device_properties, /*owns_device=*/false, host_allocator, /*allocator=*/NULL); *logical_device_handle->mutable_value() = logical_device; @@ -1185,7 +1362,8 @@ IREE_API_EXPORT iree_status_t iree_hal_vulkan_wrap_device( iree_status_t status = iree_hal_vulkan_device_create_internal( /*driver=*/NULL, identifier, enabled_features, options, instance, physical_device, logical_device_handle, &enabled_device_extensions, - compute_queue_set, transfer_queue_set, host_allocator, out_device); + &device_properties, compute_queue_set, transfer_queue_set, host_allocator, + out_device); logical_device_handle->ReleaseReference(); return status; @@ -1238,14 +1416,13 @@ static iree_status_t iree_hal_vulkan_device_query_i64( iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); *out_value = 0; - if (iree_string_view_equal(category, - iree_make_cstring_view("hal.executable.format"))) { - if (iree_string_view_equal(key, - iree_make_cstring_view("vulkan-spirv-fb"))) { + if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) { + if (iree_string_view_equal(key, IREE_SV("vulkan-spirv-fb"))) { // Base SPIR-V always supported. *out_value = 1; - } else if (iree_string_view_equal( - key, iree_make_cstring_view("vulkan-spirv-fb-ptr"))) { + return iree_ok_status(); + } + if (iree_string_view_equal(key, IREE_SV("vulkan-spirv-fb-ptr"))) { // SPIR-V with device addresses is optionally supported based on whether // we have device feature support. *out_value = iree_all_bits_set( @@ -1253,8 +1430,38 @@ static iree_status_t iree_hal_vulkan_device_query_i64( IREE_HAL_VULKAN_FEATURE_ENABLE_BUFFER_DEVICE_ADDRESSES) ? 1 : 0; + return iree_ok_status(); + } + } + + // Note that the device queries used here should match the ones used in + // buildDeviceQueryRegion() on the compiler side. + if (iree_string_view_equal(category, IREE_SV("hal.dispatch"))) { + if (iree_string_view_equal(key, IREE_SV("compute.f"))) { + *out_value = device->logical_device->supported_properties().compute_float; + return iree_ok_status(); + } + if (iree_string_view_equal(key, IREE_SV("compute.i"))) { + *out_value = device->logical_device->supported_properties().compute_int; + return iree_ok_status(); + } + if (iree_string_view_equal(key, IREE_SV("storage"))) { + *out_value = device->logical_device->supported_properties().storage; + return iree_ok_status(); + } + if (iree_string_view_equal(key, IREE_SV("subgroup"))) { + *out_value = device->logical_device->supported_properties().subgroup; + return iree_ok_status(); + } + if (iree_string_view_equal(key, IREE_SV("dotprod"))) { + *out_value = device->logical_device->supported_properties().dot_product; + return iree_ok_status(); + } + if (iree_string_view_equal(key, IREE_SV("coopmatrix"))) { + *out_value = + device->logical_device->supported_properties().cooperative_matrix; + return iree_ok_status(); } - return iree_ok_status(); } return iree_make_status(