Skip to content

Commit

Permalink
[spirv][vulkan] Enable device query generation and execution (#15977)
Browse files Browse the repository at this point in the history
This commit adds a pass to materialize executable required SPIR-V
capabilities into proper device queries inside the associated
hal.executable.condition ops. Linking is updated accordingly to
unique and preseve the feature checks. The Vulkan HAL driver
is updated accordingly to probe the implementation and match
against the device queries.

Fixes #15786
  • Loading branch information
antiagainst authored Jan 8, 2024
1 parent b55ba25 commit 852684a
Show file tree
Hide file tree
Showing 25 changed files with 1,000 additions and 86 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ iree_compiler_cc_library(
"SPIRVLinkExecutables.cpp",
"SPIRVLowerExecutableTargetPass.cpp",
"SPIRVMapMemRefStorageClass.cpp",
"SPIRVMaterializeExecutableConditions.cpp",
"SPIRVSelectLoweringStrategy.cpp",
"SPIRVTile.cpp",
"SPIRVTileAndDistribute.cpp",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ iree_cc_library(
"SPIRVLinkExecutables.cpp"
"SPIRVLowerExecutableTargetPass.cpp"
"SPIRVMapMemRefStorageClass.cpp"
"SPIRVMaterializeExecutableConditions.cpp"
"SPIRVSelectLoweringStrategy.cpp"
"SPIRVTile.cpp"
"SPIRVTileAndDistribute.cpp"
Expand Down
10 changes: 10 additions & 0 deletions compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,16 @@ void buildSPIRVCodegenPassPipeline(OpPassManager &pm, bool enableFastMath) {
// NOTE: this runs on the top-level program module containing all hal.executable
// ops.
void buildSPIRVLinkingPassPipeline(OpPassManager &passManager) {
auto &nestedExecutablePM = passManager.nest<IREE::HAL::ExecutableOp>();
// Trim the allowed target environment (version/capability/extension/etc.) to
// the minimal requirement needed by compiled spirv.module ops. This helps to
// increase the chance of linking different variant ops together.
nestedExecutablePM.addNestedPass<IREE::HAL::ExecutableVariantOp>(
createSPIRVTrimExecutableTargetEnvPass());
// Materialize the minimal required target environment into proper device
// queries to execute in the runtime.
nestedExecutablePM.addNestedPass<IREE::HAL::ExecutableVariantOp>(
createSPIRVMaterializeExecutableConditionsPass());
// Link together executables. This may produce some IR duplication.
passManager.addPass(createSPIRVLinkExecutablesPass());

Expand Down
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@ createSPIRVLowerExecutableTargetPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createSPIRVMapMemRefStorageClassPass();

/// Pass to materialize SPIR-V target requirements of hal.exectuable.variant ops
/// into hal.executable.condition regions.
std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
createSPIRVMaterializeExecutableConditionsPass();

/// Pass to tile and distribute Linalg ops with buffer semantics to
/// invocations.
std::unique_ptr<OperationPass<func::FuncOp>> createSPIRVTileAndDistributePass();
Expand Down
9 changes: 9 additions & 0 deletions compiler/src/iree/compiler/Codegen/SPIRV/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ def SPIRVMapMemRefStorageClass :
let constructor = "mlir::iree_compiler::createSPIRVMapMemRefStorageClassPass()";
}

def SPIRVMaterializeExecutableConditions :
Pass<"iree-spirv-materialize-executable-conditions",
"mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> {
let summary = "Materialize SPIR-V target requirements of hal.exectuable.variant "
"ops into hal.executable.condition regions";
let constructor =
"mlir::iree_compiler::createSPIRVMaterializeExecutableConditionsPass()";
}

def SPIRVSelectLoweringStrategy :
Pass<"iree-spirv-select-lowering-strategy-pass",
"mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> {
Expand Down
29 changes: 9 additions & 20 deletions compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,39 +6,34 @@

#include "iree/compiler/Codegen/SPIRV/PassDetail.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
#include "iree/compiler/Codegen/Utils/LinkingUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Utils/ModuleUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Pass/Pass.h"

#define DEBUG_TYPE "iree-spirv-link-executable"

namespace mlir::iree_compiler {

namespace IREE::HAL {
// Compares two ExecutableTargetAttr according to the order of used SPIR-V
// capabilities.
// Compares two ExecutableTargetAttr according to the alphabetical order of used
// SPIR-V features.
//
// Note that this is a very specific ordering per the needs of this pass--we
// guarantee that input ExectuableTargetAttr only differ w.r.t. their used
// SPIR-V features, and we want a deterministic order when mutating the IR.
bool operator<(const ExecutableTargetAttr &a, const ExecutableTargetAttr &b) {
auto aTarget = a.getConfiguration().getAs<spirv::TargetEnvAttr>(
spirv::getTargetEnvAttrName());
auto bTarget = b.getConfiguration().getAs<spirv::TargetEnvAttr>(
spirv::getTargetEnvAttrName());
auto aFeatures = aTarget.getCapabilitiesAttr();
auto bFeatures = bTarget.getCapabilitiesAttr();
auto aFeatures = a.getConfiguration().getAs<ArrayAttr>("iree.spirv.features");
auto bFeatures = b.getConfiguration().getAs<ArrayAttr>("iree.spirv.features");
for (unsigned i = 0; i < std::min(aFeatures.size(), bFeatures.size()); ++i) {
if (aFeatures[i] != bFeatures[i]) {
return cast<IntegerAttr>(aFeatures[i]).getInt() <
cast<IntegerAttr>(bFeatures[i]).getInt();
return cast<StringAttr>(aFeatures[i]).getValue() <
cast<StringAttr>(bFeatures[i]).getValue();
}
}
return aFeatures.size() < bFeatures.size();
Expand All @@ -49,11 +44,6 @@ namespace {

using IREE::HAL::ExecutableTargetAttr;

bool isSPIRVBasedBackend(IREE::HAL::ExecutableVariantOp variantOp) {
return variantOp.getTargetAttr().getConfiguration().contains(
spirv::getTargetEnvAttrName());
}

struct SPIRVLinkExecutablesPass final
: SPIRVLinkExecutablesBase<SPIRVLinkExecutablesPass> {
void runOnOperation() override {
Expand Down Expand Up @@ -104,9 +94,8 @@ struct SPIRVLinkExecutablesPass final
// sort as the unique key.
currentTargets.clear();
for (auto variant : executable.getOps<IREE::HAL::ExecutableVariantOp>()) {
ExecutableTargetAttr target = variant.getTarget();
if (isSPIRVBasedBackend(variant)) {
currentTargets.push_back(target);
if (usesSPIRVCodeGen(variant)) {
currentTargets.push_back(variant.getTarget());
}
}
llvm::sort(currentTargets);
Expand Down
Loading

0 comments on commit 852684a

Please sign in to comment.