From 57e81b74c9f6b8116072b1d6fd677200f0bf4665 Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Thu, 25 Aug 2022 11:01:10 -0400 Subject: [PATCH] Create a registry for SYCL functions callable by the compiler (#40) This PR creates the infrastructure required to register callable SYCL functions into a "registry" object. To use the new infrastructure: ``` // Initialize the SYCL function registry, this will inject function declarations into the module. const auto ®istry = SYCLFuncRegistry::create(module, builder); ... // Call a SYCL function. Value func = SYCLFuncDescriptor::call(SYCLFuncDescriptor::FuncId::Id1CtorDefault, args, registry, builder, loc); ``` --- .../Conversion/SYCLToLLVM/SYCLFuncRegistry.h | 106 ++++++++++++ .../lib/Conversion/SYCLToLLVM/CMakeLists.txt | 1 + .../SYCLToLLVM/SYCLFuncRegistry.cpp | 159 ++++++++++++++++++ 3 files changed, 266 insertions(+) create mode 100644 mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h create mode 100644 mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLFuncRegistry.cpp diff --git a/mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h b/mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h new file mode 100644 index 0000000000000..b53d601cd7ff4 --- /dev/null +++ b/mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h @@ -0,0 +1,106 @@ +//===- SYCLFuncRegistry.h - Registry of SYCL Functions ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Declare a registry of SYCL functions callable from the compiler. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_SYCLTOLLVM_SYCLFUNCREGISTRY_H +#define MLIR_CONVERSION_SYCLTOLLVM_SYCLFUNCREGISTRY_H + +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" + +namespace mlir { +namespace sycl { +class SYCLFuncRegistry; + +/// \class SYCLFuncDescriptor +/// Represents a SYCL function (defined in a registry) that can be called by the +/// compiler. +/// Note: when a new enumerator is added the corresponding SYCLFuncDescriptor +/// needs to be created in SYCLFuncRegistry constructor. +class SYCLFuncDescriptor { + friend class SYCLFuncRegistry; + +public: + /// Enumerates SYCL functions. + // clang-format off + enum class FuncId { + // Member functions for the sycl:id class. + Id1CtorDefault, // sycl::id<1>::id() + Id2CtorDefault, // sycl::id<2>::id() + Id3CtorDefault, // sycl::id<3>::id() + Id1CtorSizeT, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type) + Id2CtorSizeT, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type) + Id3CtorSizeT, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type) + Id1CtorRange, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long) + Id2CtorRange, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long) + Id3CtorRange, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long) + Id1CtorItem, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long, unsigned long) + Id2CtorItem, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long, unsigned long) + Id3CtorItem, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long, unsigned long) + + // Member functions for ..TODO.. + }; + // clang-format on + + // Call the SYCL constructor identified by \p id with the given \p args. + static Value call(FuncId id, ArrayRef args, + const SYCLFuncRegistry ®istry, OpBuilder &b, + Location loc); + +private: + /// Private constructor: only available to 'SYCLFuncRegistry'. + SYCLFuncDescriptor(FuncId id, StringRef name, Type outputTy, + ArrayRef argTys) + : id(id), name(name), outputTy(outputTy), + argTys(argTys.begin(), argTys.end()) {} + + // Inject the declaration for this function into the module. + void declareFunction(ModuleOp &module, OpBuilder &b); + +private: + FuncId id; // unique identifier for a SYCL function + StringRef name; // SYCL function name + Type outputTy; // SYCL function output type + SmallVector argTys; // SYCL function arguments types + FlatSymbolRefAttr funcRef; // Reference to the SYCL function declaration +}; + +/// \class SYCLFuncRegistry +/// Singleton class representing the set of SYCL functions callable from the +/// compiler. +class SYCLFuncRegistry { +public: + ~SYCLFuncRegistry() { instance = nullptr; } + + static const SYCLFuncRegistry create(ModuleOp &module, OpBuilder &builder); + + const SYCLFuncDescriptor &getFuncDesc(SYCLFuncDescriptor::FuncId id) const { + assert( + (registry.find(id) != registry.end()) && + "function identified by 'id' not found in the SYCL function registry"); + return registry.at(id); + } + +private: + SYCLFuncRegistry(ModuleOp &module, OpBuilder &builder); + + using Registry = std::map; + static SYCLFuncRegistry *instance; + Registry registry; +}; + +} // namespace sycl +} // namespace mlir + +#endif // MLIR_CONVERSION_SYCLTOLLVM_SYCLFUNCREGISTRY_H diff --git a/mlir-sycl/lib/Conversion/SYCLToLLVM/CMakeLists.txt b/mlir-sycl/lib/Conversion/SYCLToLLVM/CMakeLists.txt index b48f109e0d034..ed33a1f0f8eb0 100644 --- a/mlir-sycl/lib/Conversion/SYCLToLLVM/CMakeLists.txt +++ b/mlir-sycl/lib/Conversion/SYCLToLLVM/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_conversion_library(MLIRSYCLToLLVM + SYCLFuncRegistry.cpp SYCLToLLVM.cpp SYCLToLLVMPass.cpp diff --git a/mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLFuncRegistry.cpp b/mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLFuncRegistry.cpp new file mode 100644 index 0000000000000..3ad10264360b1 --- /dev/null +++ b/mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLFuncRegistry.cpp @@ -0,0 +1,159 @@ +//===- SYCLFuncRegistry - SYCL functions registry --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implement a registry of SYCL functions callable by the compiler. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "sycl-func-registry" + +using namespace mlir; +using namespace mlir::sycl; + +// TODO: move in LLVMBuilder class when available. +static FlatSymbolRefAttr getOrInsertFuncDecl(ModuleOp module, OpBuilder &b, + StringRef funcName, + Type resultType, + ArrayRef argsTypes, + bool isVarArg = false) { + if (!module.lookupSymbol(funcName)) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(module.getBody()); + LLVM::LLVMFunctionType funcType = + LLVM::LLVMFunctionType::get(resultType, argsTypes, isVarArg); + b.create(module.getLoc(), funcName, funcType); + } + return SymbolRefAttr::get(b.getContext(), funcName); +} + +//===----------------------------------------------------------------------===// +// SYCLFuncDescriptor +//===----------------------------------------------------------------------===// + +void SYCLFuncDescriptor::declareFunction(ModuleOp &module, OpBuilder &b) { + // TODO: use LLVMBuilder once available. + funcRef = getOrInsertFuncDecl(module, b, name, outputTy, argTys); +} + +Value SYCLFuncDescriptor::call(FuncId id, ArrayRef args, + const SYCLFuncRegistry ®istry, OpBuilder &b, + Location loc) { + SmallVector funcOutputTys; + const SYCLFuncDescriptor &funcDesc = registry.getFuncDesc(id); + if (!funcDesc.outputTy.isa()) + funcOutputTys.emplace_back(funcDesc.outputTy); + + // TODO: generate the call via LLVMBuilder here + // LLVMBuilder builder(b, loc); + // return builder.call(funcDesc.funcRef, ArrayRef(funcOutputsTys), args); + // TODO: we could check here the arguments against the function signature and + // assert if there is a mismatch. + auto callOp = b.create(loc, ArrayRef(funcOutputTys), + funcDesc.funcRef, args); + assert(callOp.getNumResults() == 1 && "expecting a single result"); + + return callOp.getResult(0); +} + +//===----------------------------------------------------------------------===// +// SYCLFuncRegistry +//===----------------------------------------------------------------------===// + +SYCLFuncRegistry *SYCLFuncRegistry::instance = nullptr; + +const SYCLFuncRegistry SYCLFuncRegistry::create( + ModuleOp &module, OpBuilder &builder) { + if (!instance) + instance = new SYCLFuncRegistry(module, builder); + + return *instance; +} + +SYCLFuncRegistry::SYCLFuncRegistry(ModuleOp &module, OpBuilder &builder) + : registry() { + MLIRContext *context = module.getContext(); + auto voidTy = LLVM::LLVMVoidType::get(context); + auto i8Ty = IntegerType::get(context, 8); + auto i8PtrTy = LLVM::LLVMPointerType::get(i8Ty); + auto i8PtrPtrTy = LLVM::LLVMPointerType::get(i8PtrTy); + auto i64Ty = IntegerType::get(context, 64); + auto i64PtrTy = LLVM::LLVMPointerType::get(i64Ty); + + // Construct the SYCL functions descriptors (enum, function name, signature). + // clang-format off + std::vector descriptors = { + // cl::sycl::id<1>::id() + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id1CtorDefault, + "_ZN2cl4sycl2idILi1EEC2Ev", voidTy, {i8PtrTy}), + // cl::sycl::id<2>::id() + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id2CtorDefault, + "_ZN2cl4sycl2idILi2EEC2Ev", voidTy, {i8PtrTy}), + // cl::sycl::id<3>::id() + SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id3CtorDefault, + "_ZN2cl4sycl2idILi3EEC2Ev", voidTy, {i8PtrTy}), + + // cl::sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type) + SYCLFuncDescriptor( + SYCLFuncDescriptor::FuncId::Id1CtorSizeT, + "_ZN2cl4sycl2idILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeE", + voidTy, {i8PtrTy, i64Ty}), + // cl::sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type) + SYCLFuncDescriptor( + SYCLFuncDescriptor::FuncId::Id2CtorSizeT, + "_ZN2cl4sycl2idILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeE", + voidTy, {i8PtrTy, i64Ty}), + // cl::sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type) + SYCLFuncDescriptor( + SYCLFuncDescriptor::FuncId::Id3CtorSizeT, + "_ZN2cl4sycl2idILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeE", + voidTy, {i8PtrTy, i64Ty}), + + // cl::sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long) + SYCLFuncDescriptor( + SYCLFuncDescriptor::FuncId::Id1CtorRange, + "_ZN2cl4sycl2idILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEm", + voidTy, {i8PtrTy, i64Ty, i64Ty}), + // cl::sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long) + SYCLFuncDescriptor( + SYCLFuncDescriptor::FuncId::Id2CtorRange, + "_ZN2cl4sycl2idILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEm", + voidTy, {i8PtrTy, i64Ty, i64Ty}), + // cl::sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long) + SYCLFuncDescriptor( + SYCLFuncDescriptor::FuncId::Id3CtorRange, + "_ZN2cl4sycl2idILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEm", + voidTy, {i8PtrTy, i64Ty, i64Ty}), + + // cl::sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long, unsigned long) + SYCLFuncDescriptor( + SYCLFuncDescriptor::FuncId::Id1CtorItem, + "_ZN2cl4sycl2idILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEmm", + voidTy, {i8PtrTy, i64Ty, i64Ty, i64Ty}), + // cl::sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long, unsigned long) + SYCLFuncDescriptor( + SYCLFuncDescriptor::FuncId::Id2CtorItem, + "_ZN2cl4sycl2idILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEmm", + voidTy, {i8PtrTy, i64Ty, i64Ty, i64Ty}), + // cl::sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long, unsigned long) + SYCLFuncDescriptor( + SYCLFuncDescriptor::FuncId::Id3CtorItem, + "_ZN2cl4sycl2idILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEmm", + voidTy, {i8PtrTy, i64Ty, i64Ty, i64Ty}), + }; + // clang-format on + + // Declare SYCL functions and add them to the registry. + for (SYCLFuncDescriptor &funcDesc : descriptors) { + funcDesc.declareFunction(module, builder); + registry.emplace(funcDesc.id, funcDesc); + } +}