Skip to content

Commit

Permalink
Create a registry for SYCL functions callable by the compiler (#40)
Browse files Browse the repository at this point in the history
This PR creates the infrastructure required to register callable SYCL functions into a "registry" object.
To use the new infrastructure:
```
// Initialize the SYCL function registry, this will inject function declarations into the module.
const auto &registry = SYCLFuncRegistry::create(module, builder);
...
// Call a SYCL function.
Value func = SYCLFuncDescriptor::call(SYCLFuncDescriptor::FuncId::Id1CtorDefault, args, registry, builder, loc);
```
  • Loading branch information
etiotto committed Sep 6, 2022
1 parent 7ebf3c4 commit 57e81b7
Show file tree
Hide file tree
Showing 3 changed files with 266 additions and 0 deletions.
106 changes: 106 additions & 0 deletions mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
//===- SYCLFuncRegistry.h - Registry of SYCL Functions ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Declare a registry of SYCL functions callable from the compiler.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_SYCLTOLLVM_SYCLFUNCREGISTRY_H
#define MLIR_CONVERSION_SYCLTOLLVM_SYCLFUNCREGISTRY_H

#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"

namespace mlir {
namespace sycl {
class SYCLFuncRegistry;

/// \class SYCLFuncDescriptor
/// Represents a SYCL function (defined in a registry) that can be called by the
/// compiler.
/// Note: when a new enumerator is added the corresponding SYCLFuncDescriptor
/// needs to be created in SYCLFuncRegistry constructor.
class SYCLFuncDescriptor {
friend class SYCLFuncRegistry;

public:
/// Enumerates SYCL functions.
// clang-format off
enum class FuncId {
// Member functions for the sycl:id<n> class.
Id1CtorDefault, // sycl::id<1>::id()
Id2CtorDefault, // sycl::id<2>::id()
Id3CtorDefault, // sycl::id<3>::id()
Id1CtorSizeT, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type)
Id2CtorSizeT, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type)
Id3CtorSizeT, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type)
Id1CtorRange, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long)
Id2CtorRange, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long)
Id3CtorRange, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long)
Id1CtorItem, // sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long, unsigned long)
Id2CtorItem, // sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long, unsigned long)
Id3CtorItem, // sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long, unsigned long)

// Member functions for ..TODO..
};
// clang-format on

// Call the SYCL constructor identified by \p id with the given \p args.
static Value call(FuncId id, ArrayRef<Value> args,
const SYCLFuncRegistry &registry, OpBuilder &b,
Location loc);

private:
/// Private constructor: only available to 'SYCLFuncRegistry'.
SYCLFuncDescriptor(FuncId id, StringRef name, Type outputTy,
ArrayRef<Type> argTys)
: id(id), name(name), outputTy(outputTy),
argTys(argTys.begin(), argTys.end()) {}

// Inject the declaration for this function into the module.
void declareFunction(ModuleOp &module, OpBuilder &b);

private:
FuncId id; // unique identifier for a SYCL function
StringRef name; // SYCL function name
Type outputTy; // SYCL function output type
SmallVector<Type, 4> argTys; // SYCL function arguments types
FlatSymbolRefAttr funcRef; // Reference to the SYCL function declaration
};

/// \class SYCLFuncRegistry
/// Singleton class representing the set of SYCL functions callable from the
/// compiler.
class SYCLFuncRegistry {
public:
~SYCLFuncRegistry() { instance = nullptr; }

static const SYCLFuncRegistry create(ModuleOp &module, OpBuilder &builder);

const SYCLFuncDescriptor &getFuncDesc(SYCLFuncDescriptor::FuncId id) const {
assert(
(registry.find(id) != registry.end()) &&
"function identified by 'id' not found in the SYCL function registry");
return registry.at(id);
}

private:
SYCLFuncRegistry(ModuleOp &module, OpBuilder &builder);

using Registry = std::map<SYCLFuncDescriptor::FuncId, SYCLFuncDescriptor>;
static SYCLFuncRegistry *instance;
Registry registry;
};

} // namespace sycl
} // namespace mlir

#endif // MLIR_CONVERSION_SYCLTOLLVM_SYCLFUNCREGISTRY_H
1 change: 1 addition & 0 deletions mlir-sycl/lib/Conversion/SYCLToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_conversion_library(MLIRSYCLToLLVM
SYCLFuncRegistry.cpp
SYCLToLLVM.cpp
SYCLToLLVMPass.cpp

Expand Down
159 changes: 159 additions & 0 deletions mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLFuncRegistry.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
//===- SYCLFuncRegistry - SYCL functions registry --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implement a registry of SYCL functions callable by the compiler.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "sycl-func-registry"

using namespace mlir;
using namespace mlir::sycl;

// TODO: move in LLVMBuilder class when available.
static FlatSymbolRefAttr getOrInsertFuncDecl(ModuleOp module, OpBuilder &b,
StringRef funcName,
Type resultType,
ArrayRef<Type> argsTypes,
bool isVarArg = false) {
if (!module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(module.getBody());
LLVM::LLVMFunctionType funcType =
LLVM::LLVMFunctionType::get(resultType, argsTypes, isVarArg);
b.create<LLVM::LLVMFuncOp>(module.getLoc(), funcName, funcType);
}
return SymbolRefAttr::get(b.getContext(), funcName);
}

//===----------------------------------------------------------------------===//
// SYCLFuncDescriptor
//===----------------------------------------------------------------------===//

void SYCLFuncDescriptor::declareFunction(ModuleOp &module, OpBuilder &b) {
// TODO: use LLVMBuilder once available.
funcRef = getOrInsertFuncDecl(module, b, name, outputTy, argTys);
}

Value SYCLFuncDescriptor::call(FuncId id, ArrayRef<Value> args,
const SYCLFuncRegistry &registry, OpBuilder &b,
Location loc) {
SmallVector<Type, 1> funcOutputTys;
const SYCLFuncDescriptor &funcDesc = registry.getFuncDesc(id);
if (!funcDesc.outputTy.isa<LLVM::LLVMVoidType>())
funcOutputTys.emplace_back(funcDesc.outputTy);

// TODO: generate the call via LLVMBuilder here
// LLVMBuilder builder(b, loc);
// return builder.call(funcDesc.funcRef, ArrayRef<Type>(funcOutputsTys), args);
// TODO: we could check here the arguments against the function signature and
// assert if there is a mismatch.
auto callOp = b.create<LLVM::CallOp>(loc, ArrayRef<Type>(funcOutputTys),
funcDesc.funcRef, args);
assert(callOp.getNumResults() == 1 && "expecting a single result");

return callOp.getResult(0);
}

//===----------------------------------------------------------------------===//
// SYCLFuncRegistry
//===----------------------------------------------------------------------===//

SYCLFuncRegistry *SYCLFuncRegistry::instance = nullptr;

const SYCLFuncRegistry SYCLFuncRegistry::create(
ModuleOp &module, OpBuilder &builder) {
if (!instance)
instance = new SYCLFuncRegistry(module, builder);

return *instance;
}

SYCLFuncRegistry::SYCLFuncRegistry(ModuleOp &module, OpBuilder &builder)
: registry() {
MLIRContext *context = module.getContext();
auto voidTy = LLVM::LLVMVoidType::get(context);
auto i8Ty = IntegerType::get(context, 8);
auto i8PtrTy = LLVM::LLVMPointerType::get(i8Ty);
auto i8PtrPtrTy = LLVM::LLVMPointerType::get(i8PtrTy);
auto i64Ty = IntegerType::get(context, 64);
auto i64PtrTy = LLVM::LLVMPointerType::get(i64Ty);

// Construct the SYCL functions descriptors (enum, function name, signature).
// clang-format off
std::vector<SYCLFuncDescriptor> descriptors = {
// cl::sycl::id<1>::id()
SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id1CtorDefault,
"_ZN2cl4sycl2idILi1EEC2Ev", voidTy, {i8PtrTy}),
// cl::sycl::id<2>::id()
SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id2CtorDefault,
"_ZN2cl4sycl2idILi2EEC2Ev", voidTy, {i8PtrTy}),
// cl::sycl::id<3>::id()
SYCLFuncDescriptor(SYCLFuncDescriptor::FuncId::Id3CtorDefault,
"_ZN2cl4sycl2idILi3EEC2Ev", voidTy, {i8PtrTy}),

// cl::sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type)
SYCLFuncDescriptor(
SYCLFuncDescriptor::FuncId::Id1CtorSizeT,
"_ZN2cl4sycl2idILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeE",
voidTy, {i8PtrTy, i64Ty}),
// cl::sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type)
SYCLFuncDescriptor(
SYCLFuncDescriptor::FuncId::Id2CtorSizeT,
"_ZN2cl4sycl2idILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeE",
voidTy, {i8PtrTy, i64Ty}),
// cl::sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type)
SYCLFuncDescriptor(
SYCLFuncDescriptor::FuncId::Id3CtorSizeT,
"_ZN2cl4sycl2idILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeE",
voidTy, {i8PtrTy, i64Ty}),

// cl::sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long)
SYCLFuncDescriptor(
SYCLFuncDescriptor::FuncId::Id1CtorRange,
"_ZN2cl4sycl2idILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEm",
voidTy, {i8PtrTy, i64Ty, i64Ty}),
// cl::sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long)
SYCLFuncDescriptor(
SYCLFuncDescriptor::FuncId::Id2CtorRange,
"_ZN2cl4sycl2idILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEm",
voidTy, {i8PtrTy, i64Ty, i64Ty}),
// cl::sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long)
SYCLFuncDescriptor(
SYCLFuncDescriptor::FuncId::Id3CtorRange,
"_ZN2cl4sycl2idILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEm",
voidTy, {i8PtrTy, i64Ty, i64Ty}),

// cl::sycl::id<1>::id<1>(std::enable_if<(1)==(1), unsigned long>::type, unsigned long, unsigned long)
SYCLFuncDescriptor(
SYCLFuncDescriptor::FuncId::Id1CtorItem,
"_ZN2cl4sycl2idILi1EEC2ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEmm",
voidTy, {i8PtrTy, i64Ty, i64Ty, i64Ty}),
// cl::sycl::id<2>::id<2>(std::enable_if<(2)==(2), unsigned long>::type, unsigned long, unsigned long)
SYCLFuncDescriptor(
SYCLFuncDescriptor::FuncId::Id2CtorItem,
"_ZN2cl4sycl2idILi2EEC2ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEmm",
voidTy, {i8PtrTy, i64Ty, i64Ty, i64Ty}),
// cl::sycl::id<3>::id<3>(std::enable_if<(3)==(3), unsigned long>::type, unsigned long, unsigned long)
SYCLFuncDescriptor(
SYCLFuncDescriptor::FuncId::Id3CtorItem,
"_ZN2cl4sycl2idILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEmm",
voidTy, {i8PtrTy, i64Ty, i64Ty, i64Ty}),
};
// clang-format on

// Declare SYCL functions and add them to the registry.
for (SYCLFuncDescriptor &funcDesc : descriptors) {
funcDesc.declareFunction(module, builder);
registry.emplace(funcDesc.id, funcDesc);
}
}

0 comments on commit 57e81b7

Please sign in to comment.