-
Notifications
You must be signed in to change notification settings - Fork 756
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create a registry for SYCL functions callable by the compiler (#40)
This PR creates the infrastructure required to register callable SYCL functions into a "registry" object. To use the new infrastructure: ``` // Initialize the SYCL function registry, this will inject function declarations into the module. const auto ®istry = SYCLFuncRegistry::create(module, builder); ... // Call a SYCL function. Value func = SYCLFuncDescriptor::call(SYCLFuncDescriptor::FuncId::Id1CtorDefault, args, registry, builder, loc); ```
- Loading branch information
Showing
3 changed files
with
266 additions
and
0 deletions.
There are no files selected for viewing
106 changes: 106 additions & 0 deletions
106
mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
//===- SYCLFuncRegistry.h - Registry of SYCL Functions ----------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Declare a registry of SYCL functions callable from the compiler. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_CONVERSION_SYCLTOLLVM_SYCLFUNCREGISTRY_H | ||
#define MLIR_CONVERSION_SYCLTOLLVM_SYCLFUNCREGISTRY_H | ||
|
||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h" | ||
#include "mlir/IR/BuiltinAttributes.h" | ||
#include "mlir/IR/BuiltinTypes.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/IR/Types.h" | ||
|
||
namespace mlir { | ||
namespace sycl { | ||
class SYCLFuncRegistry; | ||
|
||
/// \class SYCLFuncDescriptor | ||
/// Represents a SYCL function (defined in a registry) that can be called by the | ||
/// compiler. | ||
/// Note: when a new enumerator is added the corresponding SYCLFuncDescriptor | ||
/// needs to be created in SYCLFuncRegistry constructor. | ||
class SYCLFuncDescriptor { | ||
friend class SYCLFuncRegistry; | ||
|
||
public: | ||
/// Enumerates SYCL functions. | ||
// clang-format off | ||
enum class FuncId { | ||
// Member functions for the sycl:id<n> class. | ||
Id1CtorDefault, // sycl::id<1>::id() | ||
Id2CtorDefault, // sycl::id<2>::id() | ||
Id3CtorDefault, // sycl::id<3>::id() | ||
Id1CtorSizeT, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type) | ||
Id2CtorSizeT, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type) | ||
Id3CtorSizeT, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type) | ||
Id1CtorRange, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long) | ||
Id2CtorRange, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long) | ||
Id3CtorRange, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long) | ||
Id1CtorItem, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long, unsigned long) | ||
Id2CtorItem, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long, unsigned long) | ||
Id3CtorItem, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long, unsigned long) | ||
|
||
// Member functions for ..TODO.. | ||
}; | ||
// clang-format on | ||
|
||
// Call the SYCL constructor identified by \p id with the given \p args. | ||
static Value call(FuncId id, ArrayRef<Value> args, | ||
const SYCLFuncRegistry ®istry, OpBuilder &b, | ||
Location loc); | ||
|
||
private: | ||
/// Private constructor: only available to 'SYCLFuncRegistry'. | ||
SYCLFuncDescriptor(FuncId id, StringRef name, Type outputTy, | ||
ArrayRef<Type> argTys) | ||
: id(id), name(name), outputTy(outputTy), | ||
argTys(argTys.begin(), argTys.end()) {} | ||
|
||
// Inject the declaration for this function into the module. | ||
void declareFunction(ModuleOp &module, OpBuilder &b); | ||
|
||
private: | ||
FuncId id; // unique identifier for a SYCL function | ||
StringRef name; // SYCL function name | ||
Type outputTy; // SYCL function output type | ||
SmallVector<Type, 4> argTys; // SYCL function arguments types | ||
FlatSymbolRefAttr funcRef; // Reference to the SYCL function declaration | ||
}; | ||
|
||
/// \class SYCLFuncRegistry | ||
/// Singleton class representing the set of SYCL functions callable from the | ||
/// compiler. | ||
class SYCLFuncRegistry { | ||
public: | ||
~SYCLFuncRegistry() { instance = nullptr; } | ||
|
||
static const SYCLFuncRegistry create(ModuleOp &module, OpBuilder &builder); | ||
|
||
const SYCLFuncDescriptor &getFuncDesc(SYCLFuncDescriptor::FuncId id) const { | ||
assert( | ||
(registry.find(id) != registry.end()) && | ||
"function identified by 'id' not found in the SYCL function registry"); | ||
return registry.at(id); | ||
} | ||
|
||
private: | ||
SYCLFuncRegistry(ModuleOp &module, OpBuilder &builder); | ||
|
||
using Registry = std::map<SYCLFuncDescriptor::FuncId, SYCLFuncDescriptor>; | ||
static SYCLFuncRegistry *instance; | ||
Registry registry; | ||
}; | ||
|
||
} // namespace sycl | ||
} // namespace mlir | ||
|
||
#endif // MLIR_CONVERSION_SYCLTOLLVM_SYCLFUNCREGISTRY_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
add_mlir_conversion_library(MLIRSYCLToLLVM | ||
SYCLFuncRegistry.cpp | ||
SYCLToLLVM.cpp | ||
SYCLToLLVMPass.cpp | ||
|
||
|
159 changes: 159 additions & 0 deletions
159
mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLFuncRegistry.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
//===- SYCLFuncRegistry - SYCL functions registry --------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Implement a registry of SYCL functions callable by the compiler. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h" | ||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
#include "llvm/Support/Debug.h" | ||
|
||
#define DEBUG_TYPE "sycl-func-registry" | ||
|
||
using namespace mlir; | ||
using namespace mlir::sycl; | ||
|
||
// TODO: move in LLVMBuilder class when available. | ||
static FlatSymbolRefAttr getOrInsertFuncDecl(ModuleOp module, OpBuilder &b, | ||
StringRef funcName, | ||
Type resultType, | ||
ArrayRef<Type> argsTypes, | ||
bool isVarArg = false) { | ||
if (!module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) { | ||
OpBuilder::InsertionGuard guard(b); | ||
b.setInsertionPointToStart(module.getBody()); | ||
LLVM::LLVMFunctionType funcType = | ||
LLVM::LLVMFunctionType::get(resultType, argsTypes, isVarArg); | ||
b.create<LLVM::LLVMFuncOp>(module.getLoc(), funcName, funcType); | ||
} | ||
return SymbolRefAttr::get(b.getContext(), funcName); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// SYCLFuncDescriptor | ||
//===----------------------------------------------------------------------===// | ||
|
||
void SYCLFuncDescriptor::declareFunction(ModuleOp &module, OpBuilder &b) { | ||
// TODO: use LLVMBuilder once available. | ||
funcRef = getOrInsertFuncDecl(module, b, name, outputTy, argTys); | ||
} | ||
|
||
Value SYCLFuncDescriptor::call(FuncId id, ArrayRef<Value> args, | ||
const SYCLFuncRegistry ®istry, OpBuilder &b, | ||
Location loc) { | ||
SmallVector<Type, 1> funcOutputTys; | ||
const SYCLFuncDescriptor &funcDesc = registry.getFuncDesc(id); | ||
if (!funcDesc.outputTy.isa<LLVM::LLVMVoidType>()) | ||
funcOutputTys.emplace_back(funcDesc.outputTy); | ||
|
||
// TODO: generate the call via LLVMBuilder here | ||
// LLVMBuilder builder(b, loc); | ||
// return builder.call(funcDesc.funcRef, ArrayRef<Type>(funcOutputsTys), args); | ||
// TODO: we could check here the arguments against the function signature and | ||
// assert if there is a mismatch. | ||
auto callOp = b.create<LLVM::CallOp>(loc, ArrayRef<Type>(funcOutputTys), | ||
funcDesc.funcRef, args); | ||
assert(callOp.getNumResults() == 1 && "expecting a single result"); | ||
|
||
return callOp.getResult(0); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// SYCLFuncRegistry | ||
//===----------------------------------------------------------------------===// | ||
|
||
SYCLFuncRegistry *SYCLFuncRegistry::instance = nullptr; | ||
|
||
const SYCLFuncRegistry SYCLFuncRegistry::create( | ||
ModuleOp &module, OpBuilder &builder) { | ||
if (!instance) | ||
instance = new SYCLFuncRegistry(module, builder); | ||
|
||
return *instance; | ||
} | ||
|
||
SYCLFuncRegistry::SYCLFuncRegistry(ModuleOp &module, OpBuilder &builder) | ||
: registry() { | ||
MLIRContext *context = module.getContext(); | ||
auto voidTy = LLVM::LLVMVoidType::get(context); | ||
auto i8Ty = IntegerType::get(context, 8); | ||
auto i8PtrTy = LLVM::LLVMPointerType::get(i8Ty); | ||
auto i8PtrPtrTy = LLVM::LLVMPointerType::get(i8PtrTy); | ||
auto i64Ty = IntegerType::get(context, 64); | ||
auto i64PtrTy = LLVM::LLVMPointerType::get(i64Ty); | ||
|
||
// Construct the SYCL functions descriptors (enum, function name, signature). | ||
// clang-format off | ||
std::vector<SYCLFuncDescriptor> descriptors = { | ||
// cl::sycl::id<1>::id() | ||
SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id1CtorDefault, | ||
"_ZN2cl4sycl2idILi1EEC2Ev", voidTy, {i8PtrTy}), | ||
// cl::sycl::id<2>::id() | ||
SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id2CtorDefault, | ||
"_ZN2cl4sycl2idILi2EEC2Ev", voidTy, {i8PtrTy}), | ||
// cl::sycl::id<3>::id() | ||
SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id3CtorDefault, | ||
"_ZN2cl4sycl2idILi3EEC2Ev", voidTy, {i8PtrTy}), | ||
|
||
// cl::sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type) | ||
SYCLFuncDescriptor( | ||
SYCLFuncDescriptor::FuncId::Id1CtorSizeT, | ||
"_ZN2cl4sycl2idILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeE", | ||
voidTy, {i8PtrTy, i64Ty}), | ||
// cl::sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type) | ||
SYCLFuncDescriptor( | ||
SYCLFuncDescriptor::FuncId::Id2CtorSizeT, | ||
"_ZN2cl4sycl2idILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeE", | ||
voidTy, {i8PtrTy, i64Ty}), | ||
// cl::sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type) | ||
SYCLFuncDescriptor( | ||
SYCLFuncDescriptor::FuncId::Id3CtorSizeT, | ||
"_ZN2cl4sycl2idILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeE", | ||
voidTy, {i8PtrTy, i64Ty}), | ||
|
||
// cl::sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long) | ||
SYCLFuncDescriptor( | ||
SYCLFuncDescriptor::FuncId::Id1CtorRange, | ||
"_ZN2cl4sycl2idILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEm", | ||
voidTy, {i8PtrTy, i64Ty, i64Ty}), | ||
// cl::sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long) | ||
SYCLFuncDescriptor( | ||
SYCLFuncDescriptor::FuncId::Id2CtorRange, | ||
"_ZN2cl4sycl2idILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEm", | ||
voidTy, {i8PtrTy, i64Ty, i64Ty}), | ||
// cl::sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long) | ||
SYCLFuncDescriptor( | ||
SYCLFuncDescriptor::FuncId::Id3CtorRange, | ||
"_ZN2cl4sycl2idILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEm", | ||
voidTy, {i8PtrTy, i64Ty, i64Ty}), | ||
|
||
// cl::sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long, unsigned long) | ||
SYCLFuncDescriptor( | ||
SYCLFuncDescriptor::FuncId::Id1CtorItem, | ||
"_ZN2cl4sycl2idILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEmm", | ||
voidTy, {i8PtrTy, i64Ty, i64Ty, i64Ty}), | ||
// cl::sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long, unsigned long) | ||
SYCLFuncDescriptor( | ||
SYCLFuncDescriptor::FuncId::Id2CtorItem, | ||
"_ZN2cl4sycl2idILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEmm", | ||
voidTy, {i8PtrTy, i64Ty, i64Ty, i64Ty}), | ||
// cl::sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long, unsigned long) | ||
SYCLFuncDescriptor( | ||
SYCLFuncDescriptor::FuncId::Id3CtorItem, | ||
"_ZN2cl4sycl2idILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEmm", | ||
voidTy, {i8PtrTy, i64Ty, i64Ty, i64Ty}), | ||
}; | ||
// clang-format on | ||
|
||
// Declare SYCL functions and add them to the registry. | ||
for (SYCLFuncDescriptor &funcDesc : descriptors) { | ||
funcDesc.declareFunction(module, builder); | ||
registry.emplace(funcDesc.id, funcDesc); | ||
} | ||
} |