diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel index 6e1f20720755..86edaec83ec6 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel @@ -66,6 +66,7 @@ iree_compiler_cc_library( "SPIRVLinkExecutables.cpp", "SPIRVLowerExecutableTargetPass.cpp", "SPIRVMapMemRefStorageClass.cpp", + "SPIRVMaterializeExecutableConditions.cpp", "SPIRVSelectLoweringStrategy.cpp", "SPIRVTile.cpp", "SPIRVTileAndDistribute.cpp", diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt index 7d4d5229432b..6fd74650453a 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt @@ -65,6 +65,7 @@ iree_cc_library( "SPIRVLinkExecutables.cpp" "SPIRVLowerExecutableTargetPass.cpp" "SPIRVMapMemRefStorageClass.cpp" + "SPIRVMaterializeExecutableConditions.cpp" "SPIRVSelectLoweringStrategy.cpp" "SPIRVTile.cpp" "SPIRVTileAndDistribute.cpp" diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp index 3049d59151c7..13f7dcdd55ef 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp @@ -682,6 +682,16 @@ void buildSPIRVCodegenPassPipeline(OpPassManager &pm, bool enableFastMath) { // NOTE: this runs on the top-level program module containing all hal.executable // ops. void buildSPIRVLinkingPassPipeline(OpPassManager &passManager) { + auto &nestedExecutablePM = passManager.nest(); + // Trim the allowed target environment (version/capability/extension/etc.) to + // the minimal requirement needed by compiled spirv.module ops. This helps to + // increase the chance of linking different variant ops together. 
+ nestedExecutablePM.addNestedPass( + createSPIRVTrimExecutableTargetEnvPass()); + // Materialize the minimal required target environment into proper device + // queries to execute in the runtime. + nestedExecutablePM.addNestedPass( + createSPIRVMaterializeExecutableConditionsPass()); // Link together executables. This may produce some IR duplication. passManager.addPass(createSPIRVLinkExecutablesPass()); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h index 63b8c14e738b..dc4a22c2c3c7 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h @@ -134,6 +134,11 @@ createSPIRVLowerExecutableTargetPass(); std::unique_ptr> createSPIRVMapMemRefStorageClassPass(); +/// Pass to materialize SPIR-V target requirements of hal.executable.variant ops +/// into hal.executable.condition regions. +std::unique_ptr> +createSPIRVMaterializeExecutableConditionsPass(); + /// Pass to tile and distribute Linalg ops with buffer semantics to /// invocations. 
std::unique_ptr> createSPIRVTileAndDistributePass(); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td index 1456ed625c6a..8ec18269f8d8 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td @@ -91,6 +91,15 @@ def SPIRVMapMemRefStorageClass : let constructor = "mlir::iree_compiler::createSPIRVMapMemRefStorageClassPass()"; } +def SPIRVMaterializeExecutableConditions : + Pass<"iree-spirv-materialize-executable-conditions", + "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> { + let summary = "Materialize SPIR-V target requirements of hal.exectuable.variant " + "ops into hal.executable.condition regions"; + let constructor = + "mlir::iree_compiler::createSPIRVMaterializeExecutableConditionsPass()"; +} + def SPIRVSelectLoweringStrategy : Pass<"iree-spirv-select-lowering-strategy-pass", "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> { diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp index df77b24c3ee3..bd389783ebba 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp @@ -6,15 +6,14 @@ #include "iree/compiler/Codegen/SPIRV/PassDetail.h" #include "iree/compiler/Codegen/SPIRV/Passes.h" +#include "iree/compiler/Codegen/SPIRV/Utils.h" #include "iree/compiler/Codegen/Utils/LinkingUtils.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" #include "iree/compiler/Utils/ModuleUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" -#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/Pass/Pass.h" #define DEBUG_TYPE "iree-spirv-link-executable" @@ -22,23 +21,19 @@ namespace 
mlir::iree_compiler { namespace IREE::HAL { -// Compares two ExecutableTargetAttr according to the order of used SPIR-V -// capabilities. +// Compares two ExecutableTargetAttr according to the alphabetical order of used +// SPIR-V features. // // Note that this is a very specific ordering per the needs of this pass--we // guarantee that input ExectuableTargetAttr only differ w.r.t. their used // SPIR-V features, and we want a deterministic order when mutating the IR. bool operator<(const ExecutableTargetAttr &a, const ExecutableTargetAttr &b) { - auto aTarget = a.getConfiguration().getAs( - spirv::getTargetEnvAttrName()); - auto bTarget = b.getConfiguration().getAs( - spirv::getTargetEnvAttrName()); - auto aFeatures = aTarget.getCapabilitiesAttr(); - auto bFeatures = bTarget.getCapabilitiesAttr(); + auto aFeatures = a.getConfiguration().getAs("iree.spirv.features"); + auto bFeatures = b.getConfiguration().getAs("iree.spirv.features"); for (unsigned i = 0; i < std::min(aFeatures.size(), bFeatures.size()); ++i) { if (aFeatures[i] != bFeatures[i]) { - return cast(aFeatures[i]).getInt() < - cast(bFeatures[i]).getInt(); + return cast(aFeatures[i]).getValue() < + cast(bFeatures[i]).getValue(); } } return aFeatures.size() < bFeatures.size(); @@ -49,11 +44,6 @@ namespace { using IREE::HAL::ExecutableTargetAttr; -bool isSPIRVBasedBackend(IREE::HAL::ExecutableVariantOp variantOp) { - return variantOp.getTargetAttr().getConfiguration().contains( - spirv::getTargetEnvAttrName()); -} - struct SPIRVLinkExecutablesPass final : SPIRVLinkExecutablesBase { void runOnOperation() override { @@ -104,9 +94,8 @@ struct SPIRVLinkExecutablesPass final // sort as the unique key. 
currentTargets.clear(); for (auto variant : executable.getOps()) { - ExecutableTargetAttr target = variant.getTarget(); - if (isSPIRVBasedBackend(variant)) { - currentTargets.push_back(target); + if (usesSPIRVCodeGen(variant)) { + currentTargets.push_back(variant.getTarget()); } } llvm::sort(currentTargets); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp new file mode 100644 index 000000000000..182e65e8a279 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp @@ -0,0 +1,320 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/SPIRV/PassDetail.h" +#include "iree/compiler/Codegen/SPIRV/Passes.h" +#include "iree/compiler/Codegen/SPIRV/Utils.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" +#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler { + +namespace { + +// The list of device features potentially required by a particular kernel. +// +// Note that the fields used here should match the ones used in +// iree_hal_vulkan_device_properties_t on the runtime side. +struct KernelFeatures { + // Floating-point compute related feature bitfield: + // * 0b01: f16 + // * 0b10: f64 + // Note that f32 is assumed to always exist and does not appear in this + // bitfield. 
+ uint32_t computeFloat; + // Integer compute related feature bitfield: + // * 0b001: i8 + // * 0b010: i16 + // * 0b100: i64 + // Note that i32 or i1 is assumed to always exist and does not appear in + // this bitfield. + uint32_t computeInt; + // Storage bitwidth requirement bitfield: + // * 0b01: 8-bit + // * 0b10: 16-bit + uint32_t storage; + // Subgroup operation requirement bitfield: + // * 0b01: subgroup shuffle operations + // * 0b10: subgroup arithmetic operations + uint32_t subgroup; + // Dot product operation requirement bitfield: + // ("dotprod.<input-type>.<output-type>") + // * 0b01: dotprod.4xi8.i32 + uint32_t dotProduct; + // Cooperative matrix requirement bitfield: + // ("coopmatrix.<input-type>.<output-type>.<m>x<n>x<k>") + // * 0b01: coopmatrix.f16.f16.16x16x16 + uint32_t coopMatrix; + + KernelFeatures() + : computeFloat(0), computeInt(0), storage(0), subgroup(0), dotProduct(0), + coopMatrix(0) {} + + bool empty() const { + return computeFloat == 0 && computeInt == 0 && storage == 0 && + subgroup == 0 && dotProduct == 0 && coopMatrix == 0; + } +}; + +// Maps the given SPIR-V capability to the corresponding device query feature +// and updates features. +// +// Note that the device queries used here should match the ones used in +// iree_hal_vulkan_get_device_properties() on the runtime side. +LogicalResult mapToDeviceQuery(IREE::HAL::ExecutableExportOp entryPoint, + spirv::Capability cap, + KernelFeatures &features) { + switch (cap) { + case spirv::Capability::Shader: + // The shader capability is the root capability for graphics APIs. + // So just ignore. 
+ return success(); + + //===-------------------------------------------------------------------===// + // Compute capabilities + case spirv::Capability::Float16: + features.computeFloat |= 0b01; + return success(); + case spirv::Capability::Float64: + features.computeFloat |= 0b10; + return success(); + case spirv::Capability::Int8: + features.computeInt |= 0b001; + return success(); + case spirv::Capability::Int16: + features.computeInt |= 0b010; + return success(); + case spirv::Capability::Int64: + features.computeInt |= 0b100; + return success(); + + //===-------------------------------------------------------------------===// + // Storage capabilities + case spirv::Capability::UniformAndStorageBuffer8BitAccess: + case spirv::Capability::StorageBuffer8BitAccess: + // These capabilities allow 8-bit types to appear in interface variables of + // a particular storage class. + // So cluster them together. + features.storage |= 0b01; + return success(); + case spirv::Capability::StorageBuffer16BitAccess: + case spirv::Capability::StorageUniform16: + // These capabilities allow 16-bit types to appear in interface variables of + // a particular storage class. + // So cluster them together. + features.storage |= 0b10; + return success(); + + //===-------------------------------------------------------------------===// + // Subgroup capabilities + case spirv::Capability::GroupNonUniform: + // The basic subgroup capability provides access to builtin variables like + // subgroup ID and size. + // * In Vulkan, this is mandated starting v1.1. + // * In Metal, we have it since v2.2. + // So just ignore. 
+ return success(); + case spirv::Capability::GroupNonUniformShuffle: + features.subgroup |= 0b01; + return success(); + case spirv::Capability::GroupNonUniformArithmetic: + features.subgroup |= 0b10; + return success(); + + case spirv::Capability::DotProduct: + case spirv::Capability::DotProductInput4x8Bit: + // We only ever use vector<4xi8> -> i32 variant of dot product right now. + features.dotProduct |= 0b1; + return success(); + + //===-------------------------------------------------------------------===// + // Cooperative matrix capabilities + case spirv::Capability::CooperativeMatrixKHR: { + // Cooperative matrix has many device specific configurations. They are not + // directly reflected in the SPIR-V capabilities. We need to be explicit by + // looking at the chosen configuration. + // Format: "coopmatrix.<input-type>.<output-type>.<m>x<n>x<k>". + auto coopmatType = + entryPoint->getAttrOfType("iree.spirv.coopmatrix.type"); + auto coopmatShape = entryPoint->getAttrOfType( + "iree.spirv.coopmatrix.shape"); + if (!coopmatType || !coopmatShape) + return failure(); + + Type inputType = cast(coopmatType.getValue().front()).getValue(); + Type outputType = cast(coopmatType.getValue().back()).getValue(); + int64_t mSize = coopmatShape.asArrayRef()[0]; + int64_t nSize = coopmatShape.asArrayRef()[1]; + int64_t kSize = coopmatShape.asArrayRef()[2]; + + // We explicitly perform exact match here given that 1) we need to have the + // corresponding query in the runtime, and 2) we are not using a lot of + // configurations in CodeGen yet. + if (inputType.isF16() && outputType.isF16()) { + if (mSize == 16 && nSize == 16 && kSize == 16) { + features.coopMatrix |= 0b1; + return success(); + } + } + + return failure(); + } + + default: + break; + } + return failure(); +} + +// Builds the device query ops using the given builder. +// +// Note that the device queries used here should match the ones used in +// iree_hal_vulkan_device_query_i64() on the runtime side. 
+void buildDeviceQueryRegion(const KernelFeatures &features, Value device, + Location loc, OpBuilder &builder) { + IntegerType boolType = builder.getI1Type(); + IntegerType i32Type = builder.getI32Type(); + TypedAttr zeroAttr = builder.getZeroAttr(i32Type); + + auto buildQueryOp = [&](const char *key, uint32_t value, Value result) { + auto queryOp = builder.create( + loc, boolType, i32Type, device, builder.getStringAttr("hal.dispatch"), + builder.getStringAttr(key), zeroAttr); + auto zero = builder.create(loc, 0, 32); + auto val = builder.create(loc, value, 32); + auto andOp = builder.create(loc, queryOp.getValue(), val); + auto cmpOp = builder.create(loc, arith::CmpIPredicate::ne, + andOp, zero); + // Verify that 1) the query succeeds and 2) the capability is supported. + auto ok = builder.create(loc, queryOp.getOk(), cmpOp); + return builder.create(loc, result, ok).getResult(); + }; + + Value result = builder.create(loc, true, 1); + if (features.computeFloat) { + result = buildQueryOp("compute.f", features.computeFloat, result); + } + if (features.computeInt) { + result = buildQueryOp("compute.i", features.computeInt, result); + } + if (features.storage) { + result = buildQueryOp("storage", features.storage, result); + } + if (features.subgroup) { + result = buildQueryOp("subgroup", features.subgroup, result); + } + if (features.dotProduct) { + result = buildQueryOp("dotprod", features.dotProduct, result); + } + if (features.coopMatrix) { + result = buildQueryOp("coopmatrix", features.coopMatrix, result); + } + builder.create(loc, result); +} + +// Returns the device queries as a list of unique keys. 
+SmallVector getDeviceQueries(const KernelFeatures &features) { + SmallVector queries; + if (features.computeFloat) { + queries.push_back("compute.f=" + std::to_string(features.computeFloat)); + } + if (features.computeInt) { + queries.push_back("compute.i=" + std::to_string(features.computeInt)); + } + if (features.storage) { + queries.push_back("storage=" + std::to_string(features.storage)); + } + if (features.subgroup) { + queries.push_back("subgroup=" + std::to_string(features.subgroup)); + } + if (features.dotProduct) { + queries.push_back("dotprod=" + std::to_string(features.dotProduct)); + } + if (features.coopMatrix) { + queries.push_back("coopmatrix=" + std::to_string(features.coopMatrix)); + } + return queries; +} + +struct SPIRVMaterializeExecutableConditionsPass final + : SPIRVMaterializeExecutableConditionsBase< + SPIRVMaterializeExecutableConditionsPass> { + void runOnOperation() override { + IREE::HAL::ExecutableVariantOp variantOp = getOperation(); + if (!usesSPIRVCodeGen(variantOp)) + return; + + IREE::HAL::ExecutableTargetAttr executableTarget = variantOp.getTarget(); + DictionaryAttr configuration = executableTarget.getConfiguration(); + auto spirvTarget = configuration.getAs( + spirv::getTargetEnvAttrName()); + + auto exportOps = variantOp.getOps(); + if (!llvm::hasSingleElement(exportOps)) { + variantOp.emitError("expected to contain exactly one export op"); + return signalPassFailure(); + } + IREE::HAL::ExecutableExportOp exportOp = *exportOps.begin(); + + // Map all required SPIR-V capabilities to device queries and unique them. + // Here we only consider capabilities--version/extension is just the spec + // "container" for them; so we can ignore. 
+ KernelFeatures features; + for (spirv::Capability cap : spirvTarget.getCapabilities()) { + if (failed(mapToDeviceQuery(exportOp, cap, features))) { + variantOp.emitError("failed to handle capability ") + << spirv::stringifyCapability(cap); + return signalPassFailure(); + } + } + + OpBuilder builder(variantOp); + + // Build the hal.executable.condition op inside the variant. + if (!features.empty()) { + Value device = variantOp.createConditionOp(builder); + buildDeviceQueryRegion(features, device, device.getLoc(), builder); + } + + // Build a string list of the used queries too--this is useful for attaching + // to the executable target attribute as a unique key for the linking pass. + SmallVector strings = getDeviceQueries(features); + SmallVector queries; + queries.reserve(strings.size() + 1); + queries.push_back(variantOp.getTarget().getBackend().getValue()); + for (const std::string &s : strings) { + queries.push_back(s); + } + + // Drop the fine-grained SPIR-V target and add the coarse-grained device + // queries as a list. 
+ auto dictKeyValues = llvm::to_vector(llvm::make_filter_range( + configuration.getValue(), [](NamedAttribute attr) { + return attr.getName() != spirv::getTargetEnvAttrName(); + })); + dictKeyValues.emplace_back(builder.getStringAttr("iree.spirv.features"), + builder.getStrArrayAttr(queries)); + variantOp.setTargetAttr(IREE::HAL::ExecutableTargetAttr::get( + executableTarget.getContext(), executableTarget.getBackend(), + executableTarget.getFormat(), + DictionaryAttr::get(configuration.getContext(), dictKeyValues))); + } +}; + +} // namespace + +std::unique_ptr> +createSPIRVMaterializeExecutableConditionsPass() { + return std::make_unique(); +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp index ba8cc0cb1393..849e90e63c2e 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp @@ -12,20 +12,14 @@ #include "iree/compiler/Codegen/SPIRV/Passes.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" -#include "llvm/Support/Debug.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Transforms/Passes.h" - -#define DEBUG_TYPE "iree-spirv-select-lowering-strategy-pass" namespace mlir::iree_compiler { diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp 
b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp index a4c966e525a6..a916d2519ae8 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp @@ -27,9 +27,7 @@ #include "iree/compiler/Codegen/Utils/Utils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" -#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Conversion/VectorToGPU/VectorToGPU.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -40,7 +38,6 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" -#include "mlir/Interfaces/VectorInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -68,17 +65,23 @@ static SmallVector getTargetCooperativeOpSize(linalg::LinalgOp op) { return getTileSizes(op, 3); // For native vector sizes } -constexpr char coopMatShapeAttrName[] = "iree.spirv.coop_mat_shape"; +constexpr char coopMatTypeAttrName[] = "iree.spirv.coopmatrix.type"; +constexpr char coopMatShapeAttrName[] = "iree.spirv.coopmatrix.shape"; -/// Sets the chosen cooperative matrix shape for CodeGen onto the +/// Sets the chosen cooperative matrix type/shape for CodeGen onto the /// hal.executable.export op for the given `funcOp`. 
-void setSPIRVCooperativeMatrixShape(func::FuncOp funcOp, - ArrayRef shape) { +void setSPIRVCooperativeMatrixInfo(func::FuncOp funcOp, linalg::LinalgOp rootOp, + ArrayRef shape) { auto moduleOp = funcOp->getParentOfType(); auto exportOp = getAllEntryPoints(moduleOp).lookup(funcOp.getName()); Builder b(funcOp.getContext()); exportOp->setAttr(coopMatShapeAttrName, b.getDenseI64ArrayAttr(shape)); + auto inputType = cast(rootOp.getDpsInputs().front().getType()); + auto outputType = cast(rootOp.getDpsInits().front().getType()); + auto elementTypes = b.getTypeArrayAttr( + {inputType.getElementType(), outputType.getElementType()}); + exportOp->setAttr(coopMatTypeAttrName, elementTypes); } /// Returns the chosen cooperative matrix shape for CodeGen from the @@ -357,7 +360,7 @@ class SPIRVTileToCooperativeOpsPass final // given that after tiling and vectorization we won't have the root Linalg // op anymore. SmallVector cooperativeOpSize = getTargetCooperativeOpSize(rootOp); - setSPIRVCooperativeMatrixShape(funcOp, cooperativeOpSize); + setSPIRVCooperativeMatrixInfo(funcOp, rootOp, cooperativeOpSize); SmallVector subgroupCounts = deduceSubgroupCounts(rootOp); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTrimExecutableTargetEnv.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTrimExecutableTargetEnv.cpp index e247d9a199fa..9d07ebcd7d66 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTrimExecutableTargetEnv.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTrimExecutableTargetEnv.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Codegen/SPIRV/PassDetail.h" #include "iree/compiler/Codegen/SPIRV/Passes.h" +#include "iree/compiler/Codegen/SPIRV/Utils.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" @@ -16,23 +17,15 @@ namespace mlir::iree_compiler { namespace { -bool IsSPIRVBasedBackend(StringRef backend) { - return backend.starts_with("vulkan") || 
backend.starts_with("metal") || - backend.starts_with("webgpu"); -} - struct SPIRVTrimExecutableTargetEnvPass final : SPIRVTrimExecutableTargetEnvBase { void runOnOperation() override { IREE::HAL::ExecutableVariantOp variant = getOperation(); - if (!IsSPIRVBasedBackend(variant.getTarget().getBackend())) { - return; - } - if (variant.getObjects().has_value()) { - // Ignore external executable variants. We need to read spirv.module - // ops to get the deduced minimal list of required capability and - // extension. External source executables won't have any spirv.module - // ops inside. + if (!usesSPIRVCodeGen(variant)) { + // Ignore variants not targeting SPIR-V or external executable variants. + // We need to read spirv.module ops to get the deduced minimal list of + // required capability and extension. External source executables won't + // have any spirv.module ops inside. return; } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp index 8fdb97fb37eb..92e15c40a4eb 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp @@ -11,15 +11,28 @@ #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" -#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" -#include "mlir/Support/LogicalResult.h" +#include "mlir/IR/BuiltinAttributes.h" namespace mlir::iree_compiler { +bool usesSPIRVCodeGen(IREE::HAL::ExecutableVariantOp variantOp) { + if (variantOp.getObjects().has_value()) { + // Variants containing external executables do not go through CodeGen. 
+ return false; + } + + DictionaryAttr configuration = variantOp.getTargetAttr().getConfiguration(); + // The spirv.target_env attribute is attached if going down SPIR-V CodeGen + // pipelines. Later we turn spirv.target_env into iree.spirv.features after + // materializing device queries. + return configuration.contains(spirv::getTargetEnvAttrName()) || + configuration.contains("iree.spirv.features"); +} + const char *getSPIRVDistributeAttrName() { return "iree.spirv.distribute_dim"; } spirv::TargetEnvAttr getSPIRVTargetEnvAttr(Operation *op) { diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h index b5ddd0ebe1be..275a882afa66 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h @@ -13,6 +13,7 @@ #ifndef IREE_COMPILER_CODEGEN_SPIRV_UTILS_H_ #define IREE_COMPILER_CODEGEN_SPIRV_UTILS_H_ +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" @@ -21,6 +22,9 @@ namespace mlir::iree_compiler { +// Returns true if the given variant op uses SPIR-V CodeGen. +bool usesSPIRVCodeGen(IREE::HAL::ExecutableVariantOp variantOp); + /// Returns the attribute name carrying information about distribution. 
const char *getSPIRVDistributeAttrName(); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel index be2d14a8ce2f..d1d63c2ebccd 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel @@ -52,6 +52,7 @@ iree_lit_test_suite( "lowering_scalar_dispatch.mlir", "lowering_reduction.mlir", "map_memref_storage_class.mlir", + "materialize_executable_conditions.mlir", "pipeline_matmul_cooperative_ops.mlir", "pipeline_matmul_promotion.mlir", "pipeline_matmul_vectorization.mlir", diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt index fd9df00cfb35..fdea4975c58b 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt @@ -48,6 +48,7 @@ iree_lit_test_suite( "lowering_reduction.mlir" "lowering_scalar_dispatch.mlir" "map_memref_storage_class.mlir" + "materialize_executable_conditions.mlir" "pipeline_matmul_cooperative_ops.mlir" "pipeline_matmul_promotion.mlir" "pipeline_matmul_vectorization.mlir" diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir index ac1850449299..bd9145eaec7f 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir @@ -6,9 +6,7 @@ // // For such case we can link all executables into one, with just one variant. 
-#vulkan_target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", { - spirv.target_env = #spirv.target_env<#spirv.vce, - api=Vulkan, Unknown:DiscreteGPU, #spirv.resource_limits<>>}> +#vulkan_target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan"]}> #pipeline_layout = #hal.pipeline.layout i32 as "foo" +// CHECK: hal.executable.constant.block(%arg0: !hal.device) -> i32 as "foo" // CHECK-NEXT: = arith.constant 1 // CHECK: hal.executable.export public @dispatch_0 ordinal(0) // CHECK: hal.return %c1, %c1, %c1 @@ -165,11 +163,9 @@ util.initializer { // having one variant containing all entry points needing the same target. #vulkan_target_0 = #hal.executable.target<"vulkan", "vulkan-spirv-fb", { - spirv.target_env = #spirv.target_env<#spirv.vce, - api=Vulkan, Unknown:DiscreteGPU, #spirv.resource_limits<>>}> + iree.spirv.features = ["vulkan"]}> #vulkan_target_1 = #hal.executable.target<"vulkan", "vulkan-spirv-fb", { - spirv.target_env = #spirv.target_env<#spirv.vce, - api=Vulkan, Unknown:DiscreteGPU, #spirv.resource_limits<>>}> + iree.spirv.features = ["vulkan", "subgroup=1"]}> #pipeline_layout = #hal.pipeline.layout i1 { + %ok, %value = hal.device.query<%arg0 : !hal.device> key("hal.device.vulkan" :: "subgroup") : i1, i32 = 0 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = arith.andi %value, %c1_i32 : i32 + %1 = arith.cmpi ne, %0, %c0_i32 : i32 + %2 = arith.andi %ok, %1 : i1 + hal.return %2 : i1 + } hal.executable.constant.block(%device: !hal.device) -> i32 as "baz" { %c2 = arith.constant 2 : i32 hal.return %c2 : i32 @@ -236,6 +241,15 @@ hal.executable private @dispatch_2 { } hal.executable private @dispatch_3 { hal.executable.variant @spirv target(#vulkan_target_1) { + hal.executable.condition(%arg0: !hal.device) -> i1 { + %ok, %value = hal.device.query<%arg0 : !hal.device> key("hal.device.vulkan" :: "subgroup") : i1, i32 = 0 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 
: i32 + %0 = arith.andi %value, %c1_i32 : i32 + %1 = arith.cmpi ne, %0, %c0_i32 : i32 + %2 = arith.andi %ok, %1 : i1 + hal.return %2 : i1 + } hal.executable.export @dispatch_3 ordinal(0) layout(#pipeline_layout) { ^bb0(%arg0: !hal.device) : %c1 = arith.constant 1 : index @@ -262,10 +276,8 @@ func.func @two_target_environments() -> () { return } -// CHECK: #[[TARGET0:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", -// CHECK-SAME: #spirv.target_env<#spirv.vce -// CHECK: #[[TARGET1:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", -// CHECK-SAME: #spirv.target_env<#spirv.vce +// CHECK-DAG: #[[TARGET0:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan"]} +// CHECK-DAG: #[[TARGET1:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "subgroup=1"]} // CHECK: hal.executable private @link_executables_linked_spirv_0 { // CHECK: hal.executable.variant public @vulkan_spirv_fb target(#[[TARGET0]]) { @@ -289,6 +301,8 @@ func.func @two_target_environments() -> () { // CHECK: } // CHECK: hal.executable private @link_executables_linked_spirv_1 { // CHECK: hal.executable.variant public @vulkan_spirv_fb target(#[[TARGET1]]) { +// CHECK: hal.executable.condition(%arg0: !hal.device) -> i1 +// CHECK-NEXT: hal.device.query<%arg0 : !hal.device> key("hal.device.vulkan" :: "subgroup") // CHECK: hal.executable.constant.block(%arg0: !hal.device) -> i32 as "baz" // CHECK-NEXT: = arith.constant 2 : i32 // CHECK: hal.executable.export public @dispatch_1 ordinal(0) @@ -324,14 +338,11 @@ func.func @two_target_environments() -> () { // same set of target requirements. 
#vulkan_target_0 = #hal.executable.target<"vulkan", "vulkan-spirv-fb", { - spirv.target_env = #spirv.target_env<#spirv.vce, - api=Vulkan, Unknown:DiscreteGPU, #spirv.resource_limits<>>}> + iree.spirv.features = ["vulkan"]}> #vulkan_target_1 = #hal.executable.target<"vulkan", "vulkan-spirv-fb", { - spirv.target_env = #spirv.target_env<#spirv.vce, - api=Vulkan, Unknown:DiscreteGPU, #spirv.resource_limits<>>}> + iree.spirv.features = ["vulkan", "subgroup=1"]}> #vulkan_target_2 = #hal.executable.target<"vulkan", "vulkan-spirv-fb", { - spirv.target_env = #spirv.target_env<#spirv.vce, - api=Vulkan, Unknown:DiscreteGPU, #spirv.resource_limits<>>}> + iree.spirv.features = ["vulkan", "subgroup=2"]}> #pipeline_layout = #hal.pipeline.layout i1 { + %ok, %value = hal.device.query<%arg0 : !hal.device> key("hal.device.vulkan" :: "subgroup") : i1, i32 = 0 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = arith.andi %value, %c1_i32 : i32 + %1 = arith.cmpi ne, %0, %c0_i32 : i32 + %2 = arith.andi %ok, %1 : i1 + hal.return %2 : i1 + } hal.executable.constant.block(%device: !hal.device) -> i32 as "foo" { %c2 = arith.constant 2 : i32 hal.return %c2 : i32 @@ -400,6 +420,15 @@ hal.executable private @dispatch_1 { } } hal.executable.variant @spirv_1 target(#vulkan_target_1) { + hal.executable.condition(%arg0: !hal.device) -> i1 { + %ok, %value = hal.device.query<%arg0 : !hal.device> key("hal.device.vulkan" :: "subgroup") : i1, i32 = 0 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = arith.andi %value, %c1_i32 : i32 + %1 = arith.cmpi ne, %0, %c0_i32 : i32 + %2 = arith.andi %ok, %1 : i1 + hal.return %2 : i1 + } hal.executable.constant.block(%device: !hal.device) -> i32 as "baz" { %c4 = arith.constant 4 : i32 hal.return %c4 : i32 @@ -454,6 +483,15 @@ hal.executable private @dispatch_3 { } } hal.executable.variant @spirv_1 target(#vulkan_target_2) { + hal.executable.condition(%arg0: !hal.device) -> i1 { + %ok, %value = 
hal.device.query<%arg0 : !hal.device> key("hal.device.vulkan" :: "subgroup") : i1, i32 = 0 : i32 + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %0 = arith.andi %value, %c2_i32 : i32 + %1 = arith.cmpi ne, %0, %c0_i32 : i32 + %2 = arith.andi %ok, %1 : i1 + hal.return %2 : i1 + } hal.executable.export @dispatch_3 ordinal(0) layout(#pipeline_layout) { ^bb0(%arg0: !hal.device) : %c1 = arith.constant 1 : index @@ -469,12 +507,9 @@ hal.executable private @dispatch_3 { } } -// CHECK: #[[TARGET0:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", -// CHECK-SAME: #spirv.target_env<#spirv.vce -// CHECK: #[[TARGET1:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", -// CHECK-SAME: #spirv.target_env<#spirv.vce -// CHECK: #[[TARGET2:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", -// CHECK-SAME: #spirv.target_env<#spirv.vce +// CHECK-DAG: #[[TARGET0:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan"]} +// CHECK-DAG: #[[TARGET1:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "subgroup=1"]} +// CHECK-DAG: #[[TARGET2:.+]] = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "subgroup=2"]} // CHECK: hal.executable private @link_executables_linked_spirv { // CHECK: hal.executable.variant public @vulkan_spirv_fb_0 target(#[[TARGET0]]) { @@ -500,6 +535,10 @@ hal.executable private @dispatch_3 { // CHECK: } // CHECK: } // CHECK: hal.executable.variant public @vulkan_spirv_fb_1 target(#[[TARGET1]]) { +// CHECK: hal.executable.condition(%arg0: !hal.device) -> i1 +// CHECK-NEXT: %{{.+}}, %[[V:.+]] = hal.device.query<%arg0 : !hal.device> key("hal.device.vulkan" :: "subgroup") +// CHECK: %[[TARGET:.+]] = arith.constant 1 : i32 +// CHECK-NEXT: %{{.+}} = arith.andi %[[V]], %[[TARGET]] : i32 // CHECK: hal.executable.constant.block(%arg0: !hal.device) -> i32 as "foo" // CHECK-NEXT: = arith.constant 2 : i32 // CHECK: 
hal.executable.export public @dispatch_0 ordinal(0) diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/materialize_executable_conditions.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/materialize_executable_conditions.mlir new file mode 100644 index 000000000000..c922dd5724f5 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/materialize_executable_conditions.mlir @@ -0,0 +1,236 @@ +// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-spirv-materialize-executable-conditions)))' --mlir-print-local-scope %s | FileCheck %s + +#pipeline_layout = #hal.pipeline.layout, + <1, storage_buffer, ReadOnly>, + <2, storage_buffer> + ]> +]> + +hal.executable private @dispatch_executable { + // CHECK-LABEL: hal.executable.variant public @test_assumed_capabilities + // CHECK-SAME: target(<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan"]}>) + // CHECK-NOT: hal.executable.condition + hal.executable.variant public @test_assumed_capabilities target( + #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + }> + ) { + hal.executable.export public @test_assumed_capabilities ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @test_assumed_capabilities() "None" { spirv.Return } + spirv.EntryPoint "GLCompute" @test_assumed_capabilities + spirv.ExecutionMode @test_assumed_capabilities "LocalSize", 64, 1, 1 + } + } + } + + // CHECK-LABEL: hal.executable.variant public @test_subgroup_capabilities + // CHECK-SAME: target(<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "subgroup=3"]}>) + // CHECK-NEXT: hal.executable.condition(%[[DEV:.+]]: !hal.device) -> i1 { + // CHECK-NEXT: %[[T:.+]] = arith.constant true + // 
CHECK-NEXT: %[[OK:.+]], %[[V:.+]] = hal.device.query<%[[DEV]] : !hal.device> + // CHECK-SAME: key("hal.dispatch" :: "subgroup") : i1, i32 = 0 : i32 + // CHECK-NEXT: %[[ZERO:.+]] = arith.constant 0 : i32 + // CHECK-NEXT: %[[TARGET:.+]] = arith.constant 3 : i32 + // CHECK-NEXT: %[[CHECK:.+]] = arith.andi %[[V]], %[[TARGET]] : i32 + // CHECK-NEXT: %[[CMP:.+]] = arith.cmpi ne, %[[CHECK]], %[[ZERO]] : i32 + // CHECK-NEXT: %[[AND:.+]] = arith.andi %[[OK]], %[[CMP]] : i1 + // CHECK-NEXT: %[[RESULT:.+]] = arith.andi %[[T]], %[[AND]] : i1 + // CHECK-NEXT: hal.return %[[RESULT]] : i1 + // CHECK-NEXT: } + hal.executable.variant public @test_subgroup_capabilities target( + #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + }> + ) { + hal.executable.export public @test_subgroup_capabilities ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @test_subgroup_capabilities() "None" { spirv.Return } + spirv.EntryPoint "GLCompute" @test_subgroup_capabilities + spirv.ExecutionMode @test_subgroup_capabilities "LocalSize", 64, 1, 1 + } + } + } + + // CHECK-LABEL: hal.executable.variant public @test_8bit_storage_capabilities + // CHECK-SAME: target(<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "storage=1"]}>) + // CHECK-NEXT: hal.executable.condition(%[[DEV:.+]]: !hal.device) -> i1 { + // CHECK-NEXT: %[[T:.+]] = arith.constant true + // CHECK-NEXT: %[[OK:.+]], %[[V:.+]] = hal.device.query<%[[DEV]] : !hal.device> + // CHECK-SAME: key("hal.dispatch" :: "storage") : i1, i32 = 0 : i32 + // CHECK-NEXT: %[[ZERO:.+]] = arith.constant 0 : i32 + // CHECK-NEXT: %[[TARGET:.+]] = arith.constant 1 : i32 + // CHECK-NEXT: %[[CHECK:.+]] = arith.andi %[[V]], %[[TARGET]] : i32 + // CHECK-NEXT: %[[CMP:.+]] = arith.cmpi ne, 
%[[CHECK]], %[[ZERO]] : i32 + // CHECK-NEXT: %[[AND:.+]] = arith.andi %[[OK]], %[[CMP]] : i1 + // CHECK-NEXT: %[[RESULT:.+]] = arith.andi %[[T]], %[[AND]] : i1 + // CHECK-NEXT: hal.return %[[RESULT]] : i1 + // CHECK-NEXT: } + hal.executable.variant public @test_8bit_storage_capabilities target( + #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + }> + ) { + hal.executable.export public @test_8bit_storage_capabilities ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + spirv.module Logical GLSL450 requires + #spirv.vce { + spirv.func @test_8bit_storage_capabilities() "None" { spirv.Return } + spirv.EntryPoint "GLCompute" @test_8bit_storage_capabilities + spirv.ExecutionMode @test_8bit_storage_capabilities "LocalSize", 64, 1, 1 + } + } + } + + // CHECK-LABEL: hal.executable.variant public @test_16bit_storage_capabilities + // CHECK-SAME: target(<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "storage=2"]}>) + // CHECK-NEXT: hal.executable.condition(%[[DEV:.+]]: !hal.device) -> i1 { + // CHECK-NEXT: %[[T:.+]] = arith.constant true + // CHECK-NEXT: %[[OK:.+]], %[[V:.+]] = hal.device.query<%[[DEV]] : !hal.device> + // CHECK-SAME: key("hal.dispatch" :: "storage") : i1, i32 = 0 : i32 + // CHECK-NEXT: %[[ZERO:.+]] = arith.constant 0 : i32 + // CHECK-NEXT: %[[TARGET:.+]] = arith.constant 2 : i32 + // CHECK-NEXT: %[[CHECK:.+]] = arith.andi %[[V]], %[[TARGET]] : i32 + // CHECK-NEXT: %[[CMP:.+]] = arith.cmpi ne, %[[CHECK]], %[[ZERO]] : i32 + // CHECK-NEXT: %[[AND:.+]] = arith.andi %[[OK]], %[[CMP]] : i1 + // CHECK-NEXT: %[[RESULT:.+]] = arith.andi %[[T]], %[[AND]] : i1 + // CHECK-NEXT: hal.return %[[RESULT]] : i1 + // CHECK-NEXT: } + hal.executable.variant public @test_16bit_storage_capabilities target( + #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + 
spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + }> + ) { + hal.executable.export public @test_16bit_storage_capabilities ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + spirv.module Logical GLSL450 requires + #spirv.vce { + spirv.func @test_16bit_storage_capabilities() "None" { spirv.Return } + spirv.EntryPoint "GLCompute" @test_16bit_storage_capabilities + spirv.ExecutionMode @test_16bit_storage_capabilities "LocalSize", 64, 1, 1 + } + } + } + + // CHECK-LABEL: hal.executable.variant public @test_int_compute_capabilities + // CHECK-SAME: target(<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "compute.i=7"]}>) + // CHECK: %{{.+}}, %[[V:.+]] = hal.device.query<%{{.+}} : !hal.device> + // CHECK-SAME: key("hal.dispatch" :: "compute.i") : i1, i32 = 0 : i32 + // CHECK: %[[TARGET:.+]] = arith.constant 7 : i32 + // CHECK: %{{.+}} = arith.andi %[[V]], %[[TARGET]] : i32 + hal.executable.variant public @test_int_compute_capabilities target( + #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + }> + ) { + hal.executable.export public @test_int_compute_capabilities ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @test_int_compute_capabilities() "None" { spirv.Return } + spirv.EntryPoint "GLCompute" @test_int_compute_capabilities + spirv.ExecutionMode @test_int_compute_capabilities "LocalSize", 64, 1, 1 + } + } + } + + // CHECK-LABEL: hal.executable.variant public @test_float_compute_capabilities + // CHECK-SAME: target(<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "compute.f=3"]}>) + // CHECK: %{{.+}}, %[[V:.+]] = 
hal.device.query<%{{.+}} : !hal.device> + // CHECK-SAME: key("hal.dispatch" :: "compute.f") : i1, i32 = 0 : i32 + // CHECK: %[[TARGET:.+]] = arith.constant 3 : i32 + // CHECK: %{{.+}} = arith.andi %[[V]], %[[TARGET]] : i32 + hal.executable.variant public @test_float_compute_capabilities target( + #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + }> + ) { + hal.executable.export public @test_float_compute_capabilities ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @test_float_compute_capabilities() "None" { spirv.Return } + spirv.EntryPoint "GLCompute" @test_float_compute_capabilities + spirv.ExecutionMode @test_float_compute_capabilities "LocalSize", 64, 1, 1 + } + } + } + + // CHECK-LABEL: hal.executable.variant public @test_dot_product_capabilities + // CHECK-SAME: target(<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "dotprod=1"]}>) + // CHECK: %{{.+}}, %[[V:.+]] = hal.device.query<%{{.+}} : !hal.device> + // CHECK-SAME: key("hal.dispatch" :: "dotprod") : i1, i32 = 0 : i32 + // CHECK: %[[TARGET:.+]] = arith.constant 1 : i32 + // CHECK: %{{.+}} = arith.andi %[[V]], %[[TARGET]] : i32 + hal.executable.variant public @test_dot_product_capabilities target( + #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + }> + ) { + hal.executable.export public @test_dot_product_capabilities ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @test_dot_product_capabilities() "None" { spirv.Return } + spirv.EntryPoint "GLCompute" 
@test_dot_product_capabilities + spirv.ExecutionMode @test_dot_product_capabilities "LocalSize", 64, 1, 1 + } + } + } + + // CHECK-LABEL: hal.executable.variant public @test_cooperative_matrix_capabilities + // CHECK-SAME: target(<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan", "coopmatrix=1"]}>) + // CHECK: %{{.+}}, %[[V:.+]] = hal.device.query<%{{.+}} : !hal.device> + // CHECK-SAME: key("hal.dispatch" :: "coopmatrix") : i1, i32 = 0 : i32 + // CHECK: %[[TARGET:.+]] = arith.constant 1 : i32 + // CHECK: %{{.+}} = arith.andi %[[V]], %[[TARGET]] : i32 + hal.executable.variant public @test_cooperative_matrix_capabilities target( + #hal.executable.target<"vulkan", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + }> + ) { + hal.executable.export public @test_cooperative_matrix_capabilities ordinal(0) layout(#pipeline_layout) attributes { + iree.spirv.coopmatrix.shape = array, iree.spirv.coopmatrix.type = [f16, f16] + } { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @test_cooperative_matrix_capabilities() "None" { spirv.Return } + spirv.EntryPoint "GLCompute" @test_cooperative_matrix_capabilities + spirv.ExecutionMode @test_cooperative_matrix_capabilities "LocalSize", 64, 1, 1 + } + } + } +} diff --git a/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel index dd5dc93c46cd..823a0956b917 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel @@ -38,6 +38,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Interfaces:UKernelOpInterface", "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/HAL/IR", + "//compiler/src/iree/compiler/Utils", 
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect", "//llvm-external-projects/iree-dialects:IREELinalgExtPasses", "@llvm-project//llvm:Support", diff --git a/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt index d73ef76dba30..1e7fa2fc7f06 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt @@ -57,6 +57,7 @@ iree_cc_library( iree::compiler::Codegen::Interfaces::UKernelOpInterface iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::HAL::IR + iree::compiler::Utils PUBLIC ) diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp index c2df59e9fccf..9fd7492a9c4f 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp @@ -7,6 +7,8 @@ #include "iree/compiler/Codegen/Utils/LinkingUtils.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Utils/EquivalenceUtils.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/TypeUtilities.h" @@ -235,6 +237,29 @@ LogicalResult linkExecutablesInto( {SymbolRefAttr::get(linkedTargetOp)}); symbolReplacements.variantRefs[oldVariantRefAttr] = newVariantRefAttr; + // Move the condition op too. We need to make sure all variant's condition + // op has the same content. 
+ auto targetConditionOps = + linkedTargetOp.getOps<IREE::HAL::ExecutableConditionOp>(); + if (auto sourceConditionOp = variantOp.getConditionOp()) { + if (targetConditionOps.empty()) { + sourceConditionOp->moveBefore( + &*linkedTargetBuilder.getInsertionPoint()); + } else { + assert(llvm::hasSingleElement(targetConditionOps)); + IREE::HAL::ExecutableConditionOp referenceOp = + *targetConditionOps.begin(); + if (!isStructurallyEquivalentTo(*sourceConditionOp.getOperation(), + *referenceOp.getOperation())) { + return variantOp.emitError("contains incompatible condition op"); + } + } + } else { + if (!targetConditionOps.empty()) { + return variantOp.emitError("should contain a condition op"); + } + } + // Move any constant blocks that need to be preserved for future host // translation. There may be duplicates provided but they'll be cleaned // up in future passes. diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index 0eb86d2baf09..25d26c1e6295 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -1220,6 +1220,20 @@ DenseMap ExecutableVariantOp::gatherConstantOrdinals() { return map; } +Value ExecutableVariantOp::createConditionOp(OpBuilder &builder) { + assert(!getConditionOp() && "condition op already exists"); + + // The condition region goes at the very start of the variant body. + builder.setInsertionPointToStart(&getRegion().front()); + auto conditionOp = builder.create<IREE::HAL::ExecutableConditionOp>(getLoc()); + Block *entryPoint = conditionOp.addEntryBlock(); + Value device = entryPoint->getArgument(0); + + // Leave the builder inside the new region so callers can emit its body. + builder.setInsertionPointToStart(entryPoint); + return device; +} + Value ExecutableVariantOp::buildCondition(Value device, OpBuilder &builder) { // Base case dependent on target information. 
// TODO(multi-device): condition on device target ID and other queries that diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index 91bda119787e..164ffa28ea2c 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -2162,6 +2162,11 @@ def HAL_ExecutableVariantOp : HAL_Op<"executable.variant", [ // blocks inside the variant. DenseMap gatherConstantOrdinals(); + // Creates the new `hal.executable.condition` op in this variant op and sets + // the insertion point of the provided builder to the beginning of the new + // region. + Value createConditionOp(OpBuilder &builder); + // Returns an i1 indicating whether this variant should be selected. Value buildCondition(Value device, OpBuilder &builder); }]; diff --git a/runtime/src/iree/hal/drivers/vulkan/dynamic_symbol_tables.h b/runtime/src/iree/hal/drivers/vulkan/dynamic_symbol_tables.h index 8fb4403c28e9..6497a44c2aa7 100644 --- a/runtime/src/iree/hal/drivers/vulkan/dynamic_symbol_tables.h +++ b/runtime/src/iree/hal/drivers/vulkan/dynamic_symbol_tables.h @@ -319,7 +319,7 @@ namespace vulkan { INS_PFN(EXCLUDED, vkGetDisplayPlaneCapabilitiesKHR) \ INS_PFN(EXCLUDED, vkGetDisplayPlaneSupportedDisplaysKHR) \ INS_PFN(OPTIONAL, vkGetPhysicalDeviceCalibrateableTimeDomainsEXT) \ - INS_PFN(EXCLUDED, vkGetPhysicalDeviceCooperativeMatrixPropertiesNV) \ + INS_PFN(OPTIONAL, vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR) \ INS_PFN(EXCLUDED, vkGetPhysicalDeviceDisplayPlaneProperties2KHR) \ INS_PFN(EXCLUDED, vkGetPhysicalDeviceDisplayPlanePropertiesKHR) \ INS_PFN(EXCLUDED, vkGetPhysicalDeviceDisplayProperties2KHR) \ diff --git a/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h b/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h index 562b6513c498..4c12179c1894 100644 --- a/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h +++ 
b/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h @@ -106,4 +106,40 @@ iree_hal_vulkan_device_extensions_t iree_hal_vulkan_infer_enabled_device_extensions( const iree::hal::vulkan::DynamicSymbols* device_syms); +// Struct for supported device properties. +// +// Note that the fields used here should match the ones used in KernelFeatures +// on the compiler side. +typedef struct iree_hal_vulkan_device_properties_t { + // Floating-point compute related feature bitfield: + // * 0b01: f16 + // * 0b10: f64 + // Note that f32 is assumed to always exist and does not appear in this + // bitfield. + uint32_t compute_float : 8; + // Integer compute related feature bitfield: + // * 0b001: i8 + // * 0b010: i16 + // * 0b100: i64 + // Note that i32 or i1 is assumed to always exist and does not appear in + // this bitfield. + uint32_t compute_int : 8; + // Storage bitwidth requirement bitfield: + // * 0b01: 8-bit + // * 0b10: 16-bit + uint32_t storage : 8; + // Subgroup operation requirement bitfield: + // * 0b01: subgroup shuffle operations + // * 0b10: subgroup arithmetic operations + uint32_t subgroup : 8; + // Dot product operation requirement bitfield: + // ("dotprod.<input-type>.<output-type>") + // * 0b01: dotprod.4xi8.i32 + uint32_t dot_product : 8; + // Cooperative matrix requirement bitfield: + // ("coopmatrix.<input-element-type>.<output-element-type>.<m>x<n>x<k>") + // * 0b01: coopmatrix.f16.f16.16x16x16 + uint32_t cooperative_matrix : 8; +} iree_hal_vulkan_device_properties_t; + #endif // IREE_HAL_DRIVERS_VULKAN_EXTENSIBILITY_UTIL_H_ diff --git a/runtime/src/iree/hal/drivers/vulkan/handle_util.h b/runtime/src/iree/hal/drivers/vulkan/handle_util.h index 9cf0349e04ee..c85bcb624ec4 100644 --- a/runtime/src/iree/hal/drivers/vulkan/handle_util.h +++ b/runtime/src/iree/hal/drivers/vulkan/handle_util.h @@ -42,12 +42,14 @@ class VkDeviceHandle : public RefObject { VkDeviceHandle(DynamicSymbols* syms, VkPhysicalDevice physical_device, iree_hal_vulkan_features_t enabled_features, iree_hal_vulkan_device_extensions_t 
enabled_extensions, + iree_hal_vulkan_device_properties_t supported_properties, bool owns_device, iree_allocator_t host_allocator, const VkAllocationCallbacks* allocator = nullptr) : physical_device_(physical_device), syms_(add_ref(syms)), enabled_features_(enabled_features), enabled_extensions_(enabled_extensions), + supported_properties_(supported_properties), owns_device_(owns_device), allocator_(allocator), host_allocator_(host_allocator) {} @@ -62,6 +64,7 @@ class VkDeviceHandle : public RefObject { value_(exchange(other.value_, static_cast(VK_NULL_HANDLE))), syms_(std::move(other.syms_)), enabled_extensions_(other.enabled_extensions_), + supported_properties_(other.supported_properties_), owns_device_(other.owns_device_), allocator_(other.allocator_), host_allocator_(other.host_allocator_) {} @@ -93,12 +96,17 @@ class VkDeviceHandle : public RefObject { return enabled_extensions_; } + const iree_hal_vulkan_device_properties_t& supported_properties() const { + return supported_properties_; + } + private: VkPhysicalDevice physical_device_ = VK_NULL_HANDLE; VkDevice value_ = VK_NULL_HANDLE; ref_ptr syms_; iree_hal_vulkan_features_t enabled_features_; iree_hal_vulkan_device_extensions_t enabled_extensions_; + iree_hal_vulkan_device_properties_t supported_properties_; bool owns_device_; const VkAllocationCallbacks* allocator_ = nullptr; iree_allocator_t host_allocator_; diff --git a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc index faffd5a1dc30..41c2aac07b37 100644 --- a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc +++ b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc @@ -498,6 +498,8 @@ typedef struct iree_hal_vulkan_device_t { iree_hal_vulkan_device_flags_t flags; // Which optional extensions are active and available on the device. iree_hal_vulkan_device_extensions_t device_extensions; + // Device properties for various optional features. 
+ iree_hal_vulkan_device_properties_t device_properties; VkInstance instance; VkPhysicalDevice physical_device; @@ -690,6 +692,7 @@ static iree_status_t iree_hal_vulkan_device_create_internal( const iree_hal_vulkan_device_options_t* options, VkInstance instance, VkPhysicalDevice physical_device, VkDeviceHandle* logical_device, const iree_hal_vulkan_device_extensions_t* device_extensions, + const iree_hal_vulkan_device_properties_t* device_properties, const iree_hal_vulkan_queue_set_t* compute_queue_set, const iree_hal_vulkan_queue_set_t* transfer_queue_set, iree_allocator_t host_allocator, iree_hal_device_t** out_device) { @@ -721,6 +724,7 @@ static iree_status_t iree_hal_vulkan_device_create_internal( device->flags = options->flags; device->device_extensions = *device_extensions; + device->device_properties = *device_properties; device->instance = instance; device->physical_device = physical_device; device->logical_device = logical_device; @@ -846,6 +850,160 @@ static iree_status_t iree_hal_vulkan_device_query_extensibility_set( return iree_ok_status(); } +static iree_status_t iree_hal_vulkan_get_device_properties( + DynamicSymbols* instance_syms, VkPhysicalDevice physical_device, + iree_hal_vulkan_device_properties_t* device_properties) { + memset(device_properties, 0, sizeof(*device_properties)); + + VkPhysicalDeviceFeatures2 physical_device_features; + memset(&physical_device_features, 0, sizeof(physical_device_features)); + physical_device_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + + // + Shader float16 and int8 features. 
+ VkPhysicalDeviceShaderFloat16Int8Features shader_float16_int8_features; + memset(&shader_float16_int8_features, 0, + sizeof(shader_float16_int8_features)); + shader_float16_int8_features.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES; + shader_float16_int8_features.pNext = physical_device_features.pNext; + physical_device_features.pNext = &shader_float16_int8_features; + + // + Shader 8 bit storage features. + VkPhysicalDevice8BitStorageFeatures supported_8bit_storage_features; + memset(&supported_8bit_storage_features, 0, + sizeof(supported_8bit_storage_features)); + supported_8bit_storage_features.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES; + supported_8bit_storage_features.pNext = physical_device_features.pNext; + physical_device_features.pNext = &supported_8bit_storage_features; + + // + Shader 16 bit storage features. + VkPhysicalDevice16BitStorageFeatures supported_16bit_storage_features; + memset(&supported_16bit_storage_features, 0, + sizeof(supported_16bit_storage_features)); + supported_16bit_storage_features.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES; + supported_16bit_storage_features.pNext = physical_device_features.pNext; + physical_device_features.pNext = &supported_16bit_storage_features; + + // + Shader integer dot product features. + VkPhysicalDeviceShaderIntegerDotProductFeatures dot_product_features; + memset(&dot_product_features, 0, sizeof(dot_product_features)); + dot_product_features.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES; + dot_product_features.pNext = physical_device_features.pNext; + physical_device_features.pNext = &dot_product_features; + + // + Cooperative matrix features. 
+ VkPhysicalDeviceCooperativeMatrixFeaturesKHR coop_matrix_features; + memset(&coop_matrix_features, 0, sizeof(coop_matrix_features)); + coop_matrix_features.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; + coop_matrix_features.pNext = physical_device_features.pNext; + physical_device_features.pNext = &coop_matrix_features; + + instance_syms->vkGetPhysicalDeviceFeatures2(physical_device, + &physical_device_features); + + VkPhysicalDeviceProperties2 physical_device_properties; + memset(&physical_device_properties, 0, sizeof(physical_device_properties)); + physical_device_properties.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2; + physical_device_properties.pNext = NULL; + + // + Subgroup properties. + VkPhysicalDeviceSubgroupProperties subgroup_properties; + memset(&subgroup_properties, 0, sizeof(subgroup_properties)); + subgroup_properties.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES; + subgroup_properties.pNext = physical_device_properties.pNext; + physical_device_properties.pNext = &subgroup_properties; + + // + Shader integer dot product properties. 
+ VkPhysicalDeviceShaderIntegerDotProductProperties dot_product_properties; + memset(&dot_product_properties, 0, sizeof(dot_product_properties)); + dot_product_properties.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_PROPERTIES; + dot_product_properties.pNext = physical_device_properties.pNext; + physical_device_properties.pNext = &dot_product_properties; + + instance_syms->vkGetPhysicalDeviceProperties2(physical_device, + &physical_device_properties); + + if (shader_float16_int8_features.shaderFloat16) { + device_properties->compute_float |= 0x1u; + } + if (physical_device_features.features.shaderFloat64) { + device_properties->compute_float |= 0x2u; + } + if (shader_float16_int8_features.shaderInt8) { + device_properties->compute_int |= 0x1u; + } + if (physical_device_features.features.shaderInt16) { + device_properties->compute_int |= 0x2u; + } + if (physical_device_features.features.shaderInt64) { + device_properties->compute_int |= 0x4u; + } + if (supported_8bit_storage_features.storageBuffer8BitAccess && + supported_8bit_storage_features.uniformAndStorageBuffer8BitAccess) { + device_properties->storage |= 0x1u; + } + if (supported_16bit_storage_features.storageBuffer16BitAccess && + supported_16bit_storage_features.uniformAndStorageBuffer16BitAccess) { + device_properties->storage |= 0x2u; + } + + if (iree_all_bits_set(subgroup_properties.supportedOperations, + VK_SUBGROUP_FEATURE_SHUFFLE_BIT)) { + device_properties->subgroup |= 0x1u; + } + if (iree_all_bits_set(subgroup_properties.supportedOperations, + VK_SUBGROUP_FEATURE_ARITHMETIC_BIT)) { + device_properties->subgroup |= 0x2u; + } + + if (dot_product_features.shaderIntegerDotProduct && + dot_product_properties.integerDotProduct8BitUnsignedAccelerated && + dot_product_properties.integerDotProduct8BitSignedAccelerated && + dot_product_properties.integerDotProduct8BitMixedSignednessAccelerated && + dot_product_properties + 
+ .integerDotProductAccumulatingSaturating8BitUnsignedAccelerated && + dot_product_properties + .integerDotProductAccumulatingSaturating8BitSignedAccelerated && + dot_product_properties + .integerDotProductAccumulatingSaturating8BitMixedSignednessAccelerated) { + device_properties->dot_product |= 0x1u; + } + + if (coop_matrix_features.cooperativeMatrix && + instance_syms->vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR) { + uint32_t count = 0; + IREE_RETURN_IF_ERROR(VK_RESULT_TO_STATUS( + instance_syms->vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR( + physical_device, &count, NULL))); + VkCooperativeMatrixPropertiesKHR* properties = + (VkCooperativeMatrixPropertiesKHR*)iree_alloca( + count * sizeof(VkCooperativeMatrixPropertiesKHR)); + // Vulkan requires sType (and a NULL pNext) to be set on every output + // element before the query; the alloca'd memory is uninitialized. + for (uint32_t i = 0; i < count; ++i) { + properties[i].sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR; + properties[i].pNext = NULL; + } + IREE_RETURN_IF_ERROR(VK_RESULT_TO_STATUS( + instance_syms->vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR( + physical_device, &count, properties))); + // Bit 0b01 is set when any reported configuration is f16 x f16 -> f16 + // with a 16x16x16 shape. + for (uint32_t i = 0; i < count; ++i) { + const VkCooperativeMatrixPropertiesKHR* p = properties + i; + if (p->AType == VK_COMPONENT_TYPE_FLOAT16_KHR && + p->BType == VK_COMPONENT_TYPE_FLOAT16_KHR && + p->CType == VK_COMPONENT_TYPE_FLOAT16_KHR && p->MSize == 16 && + p->NSize == 16 && p->KSize == 16) { + device_properties->cooperative_matrix |= 0x1u; + } + } + } + + return iree_ok_status(); +} + iree_status_t iree_hal_vulkan_device_create( iree_hal_driver_t* driver, iree_string_view_t identifier, iree_hal_vulkan_features_t requested_features, @@ -966,6 +1124,15 @@ iree_status_t iree_hal_vulkan_device_create( available_shader_float16_int8_features.pNext = available_features2.pNext; available_features2.pNext = &available_shader_float16_int8_features; + // + Subgroup properties. + // NOTE(review): this chains a *properties* struct into what looks like a + // features2 pNext chain -- confirm this is intended. 
+ VkPhysicalDeviceSubgroupProperties available_subgroup_properties; + memset(&available_subgroup_properties, 0, + sizeof(available_subgroup_properties)); + available_subgroup_properties.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES; + available_subgroup_properties.pNext = available_features2.pNext; + available_features2.pNext = &available_subgroup_properties; + // + Cooperative matrix features. VkPhysicalDeviceCooperativeMatrixFeaturesKHR available_coop_matrix_features; memset(&available_coop_matrix_features, 0, @@ -1094,15 +1261,19 @@ iree_status_t iree_hal_vulkan_device_create( enabled_features2.pNext = &available_shader_float16_int8_features; } - // Enable all available coop matrix features. + // Enable all available cooperative matrix features. if (enabled_device_extensions.cooperative_matrix) { available_coop_matrix_features.pNext = enabled_features2.pNext; enabled_features2.pNext = &available_coop_matrix_features; } + iree_hal_vulkan_device_properties_t device_properties; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_get_device_properties( + instance_syms, physical_device, &device_properties)); + auto logical_device = new VkDeviceHandle( instance_syms, physical_device, enabled_features, - enabled_device_extensions, + enabled_device_extensions, device_properties, /*owns_device=*/true, host_allocator, /*allocator=*/NULL); iree_status_t status = VK_RESULT_TO_STATUS( @@ -1129,7 +1300,8 @@ iree_status_t iree_hal_vulkan_device_create( status = iree_hal_vulkan_device_create_internal( driver, identifier, enabled_features, options, instance, physical_device, logical_device, &enabled_device_extensions, - &compute_queue_set, &transfer_queue_set, host_allocator, out_device); + &device_properties, &compute_queue_set, &transfer_queue_set, + host_allocator, out_device); } logical_device->ReleaseReference(); @@ -1169,6 +1341,11 @@ IREE_API_EXPORT iree_status_t iree_hal_vulkan_wrap_device( iree_hal_vulkan_device_extensions_t enabled_device_extensions = 
iree_hal_vulkan_infer_enabled_device_extensions(device_syms.get()); + // We can still retrieve the correct device properties though. + iree_hal_vulkan_device_properties_t device_properties; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_get_device_properties( + device_syms.get(), physical_device, &device_properties)); + iree_hal_vulkan_features_t enabled_features = 0; #if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE enabled_features |= IREE_HAL_VULKAN_FEATURE_ENABLE_TRACING; @@ -1177,7 +1354,7 @@ IREE_API_EXPORT iree_status_t iree_hal_vulkan_wrap_device( // Wrap the provided VkDevice with a VkDeviceHandle for use within the HAL. auto logical_device_handle = new VkDeviceHandle( device_syms.get(), physical_device, enabled_features, - enabled_device_extensions, + enabled_device_extensions, device_properties, /*owns_device=*/false, host_allocator, /*allocator=*/NULL); *logical_device_handle->mutable_value() = logical_device; @@ -1185,7 +1362,8 @@ IREE_API_EXPORT iree_status_t iree_hal_vulkan_wrap_device( iree_status_t status = iree_hal_vulkan_device_create_internal( /*driver=*/NULL, identifier, enabled_features, options, instance, physical_device, logical_device_handle, &enabled_device_extensions, - compute_queue_set, transfer_queue_set, host_allocator, out_device); + &device_properties, compute_queue_set, transfer_queue_set, host_allocator, + out_device); logical_device_handle->ReleaseReference(); return status; @@ -1238,14 +1416,13 @@ static iree_status_t iree_hal_vulkan_device_query_i64( iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); *out_value = 0; - if (iree_string_view_equal(category, - iree_make_cstring_view("hal.executable.format"))) { - if (iree_string_view_equal(key, - iree_make_cstring_view("vulkan-spirv-fb"))) { + if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) { + if (iree_string_view_equal(key, IREE_SV("vulkan-spirv-fb"))) { // Base SPIR-V always supported. 
*out_value = 1; - } else if (iree_string_view_equal( - key, iree_make_cstring_view("vulkan-spirv-fb-ptr"))) { + return iree_ok_status(); + } + if (iree_string_view_equal(key, IREE_SV("vulkan-spirv-fb-ptr"))) { // SPIR-V with device addresses is optionally supported based on whether // we have device feature support. *out_value = iree_all_bits_set( @@ -1253,8 +1430,38 @@ static iree_status_t iree_hal_vulkan_device_query_i64( IREE_HAL_VULKAN_FEATURE_ENABLE_BUFFER_DEVICE_ADDRESSES) ? 1 : 0; + return iree_ok_status(); + } + } + + // Note that the device queries used here should match the ones used in + // buildDeviceQueryRegion() on the compiler side. + if (iree_string_view_equal(category, IREE_SV("hal.dispatch"))) { + if (iree_string_view_equal(key, IREE_SV("compute.f"))) { + *out_value = device->logical_device->supported_properties().compute_float; + return iree_ok_status(); + } + if (iree_string_view_equal(key, IREE_SV("compute.i"))) { + *out_value = device->logical_device->supported_properties().compute_int; + return iree_ok_status(); + } + if (iree_string_view_equal(key, IREE_SV("storage"))) { + *out_value = device->logical_device->supported_properties().storage; + return iree_ok_status(); + } + if (iree_string_view_equal(key, IREE_SV("subgroup"))) { + *out_value = device->logical_device->supported_properties().subgroup; + return iree_ok_status(); + } + if (iree_string_view_equal(key, IREE_SV("dotprod"))) { + *out_value = device->logical_device->supported_properties().dot_product; + return iree_ok_status(); + } + if (iree_string_view_equal(key, IREE_SV("coopmatrix"))) { + *out_value = + device->logical_device->supported_properties().cooperative_matrix; + return iree_ok_status(); } - return iree_ok_status(); } return iree_make_status(