Skip to content

Commit

Permalink
Add type converters for sycl::group and sycl::nd_item (#54)
Browse files Browse the repository at this point in the history
The runtime class of `sycl::group`:
```
template <int Dimensions = 1> class group {
...
  range<Dimensions> globalRange;
  range<Dimensions> localRange;
  range<Dimensions> groupRange;
  id<Dimensions> index;
...
}
```
The runtime class of `sycl::nd_item`:
```
template <int dimensions = 1> class nd_item {
...
  item<dimensions, true> globalItem;
  item<dimensions, false> localItem;
  group<dimensions> Group;
...
}
```
Example of LLVM IR generated directly from clang:
```
%"class.cl::sycl::group" = type { %"class.cl::sycl::range", %"class.cl::sycl::range", %"class.cl::sycl::range", %"class.cl::sycl::id" }
%"class.cl::sycl::nd_item" = type { %"class.cl::sycl::item", %"class.cl::sycl::item.0", %"class.cl::sycl::group" }
%"class.cl::sycl::item" = type { %"struct.cl::sycl::detail::ItemBase" }
%"class.cl::sycl::item.0" = type { %"struct.cl::sycl::detail::ItemBase.1" }
%"struct.cl::sycl::detail::ItemBase" = type { %"class.cl::sycl::range", %"class.cl::sycl::id", %"class.cl::sycl::id" }
%"struct.cl::sycl::detail::ItemBase.1" = type { %"class.cl::sycl::range", %"class.cl::sycl::id" }
```

Signed-off-by: Tsang, Whitney <whitney.tsang@intel.com>
  • Loading branch information
whitneywhtsang authored and etiotto committed Sep 6, 2022
1 parent 133be13 commit 9656ebf
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 55 deletions.
114 changes: 64 additions & 50 deletions mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,52 +74,6 @@ static Optional<Type> getArrayTy(MLIRContext &context, unsigned dimNum,
// Type conversion
//===----------------------------------------------------------------------===//

/// Converts SYCL array type to LLVM type.
static Optional<Type> convertArrayType(sycl::ArrayType type,
LLVMTypeConverter &converter) {
assert(type.getBody().size() == 1 &&
"Expecting SYCL array body to have size 1");
assert(type.getBody()[0].isa<MemRefType>() &&
"Expecting SYCL array body entry to be MemRefType");
assert(type.getBody()[0].cast<MemRefType>().getElementType() ==
converter.getIndexType() &&
"Expecting SYCL array body entry element type to be the index type");
return getArrayTy(converter.getContext(), type.getDimension(),
converter.getIndexType());
}

/// Converts SYCL range or id type to LLVM type, given \p dimNum - number of
/// dimensions, \p name - the expected LLVM type name, \p converter - LLVM type
/// converter.
static Optional<Type> convertRangeOrIDTy(unsigned dimNum, StringRef name,
LLVMTypeConverter &converter) {
auto convertedTy = LLVM::LLVMStructType::getIdentified(
&converter.getContext(), name.str() + "." + std::to_string(dimNum));
if (!convertedTy.isInitialized()) {
auto arrayTy =
getArrayTy(converter.getContext(), dimNum, converter.getIndexType());
if (!arrayTy.hasValue())
return llvm::None;
if (failed(convertedTy.setBody(arrayTy.getValue(), /*isPacked=*/false)))
return llvm::None;
}
return convertedTy;
}

/// Converts SYCL id type to LLVM type.
static Optional<Type> convertIDType(sycl::IDType type,
LLVMTypeConverter &converter) {
return convertRangeOrIDTy(type.getDimension(), "class.cl::sycl::id",
converter);
}

/// Converts SYCL range type to LLVM type.
static Optional<Type> convertRangeType(sycl::RangeType type,
LLVMTypeConverter &converter) {
return convertRangeOrIDTy(type.getDimension(), "class.cl::sycl::range",
converter);
}

/// Create a LLVM struct type with name \p name, and the converted \p body as
/// the body.
static Optional<Type> convertBodyType(StringRef name,
Expand Down Expand Up @@ -172,6 +126,53 @@ static Optional<Type> convertAccessorType(sycl::AccessorType type,
return convertedTy;
}

/// Converts SYCL array type to LLVM type.
static Optional<Type> convertArrayType(sycl::ArrayType type,
LLVMTypeConverter &converter) {
assert(type.getBody().size() == 1 &&
"Expecting SYCL array body to have size 1");
assert(type.getBody()[0].isa<MemRefType>() &&
"Expecting SYCL array body entry to be MemRefType");
assert(type.getBody()[0].cast<MemRefType>().getElementType() ==
converter.getIndexType() &&
"Expecting SYCL array body entry element type to be the index type");
return getArrayTy(converter.getContext(), type.getDimension(),
converter.getIndexType());
}

/// Converts SYCL group type to LLVM type.
static Optional<Type> convertGroupType(sycl::GroupType type,
LLVMTypeConverter &converter) {
return convertBodyType("class.cl::sycl::group." +
std::to_string(type.getDimension()),
type.getBody(), converter);
}

/// Converts SYCL range or id type to LLVM type, given \p dimNum - number of
/// dimensions, \p name - the expected LLVM type name, \p converter - LLVM type
/// converter.
static Optional<Type> convertRangeOrIDTy(unsigned dimNum, StringRef name,
LLVMTypeConverter &converter) {
auto convertedTy = LLVM::LLVMStructType::getIdentified(
&converter.getContext(), name.str() + "." + std::to_string(dimNum));
if (!convertedTy.isInitialized()) {
auto arrayTy =
getArrayTy(converter.getContext(), dimNum, converter.getIndexType());
if (!arrayTy.hasValue())
return llvm::None;
if (failed(convertedTy.setBody(arrayTy.getValue(), /*isPacked=*/false)))
return llvm::None;
}
return convertedTy;
}

/// Converts SYCL id type to LLVM type.
static Optional<Type> convertIDType(sycl::IDType type,
LLVMTypeConverter &converter) {
return convertRangeOrIDTy(type.getDimension(), "class.cl::sycl::id",
converter);
}

/// Converts SYCL item base type to LLVM type.
static Optional<Type> convertItemBaseType(sycl::ItemBaseType type,
LLVMTypeConverter &converter) {
Expand All @@ -190,6 +191,21 @@ static Optional<Type> convertItemType(sycl::ItemType type,
type.getBody(), converter);
}

/// Converts SYCL nd item type to LLVM type.
static Optional<Type> convertNdItemType(sycl::NdItemType type,
LLVMTypeConverter &converter) {
return convertBodyType("class.cl::sycl::nd_item." +
std::to_string(type.getDimension()),
type.getBody(), converter);
}

/// Converts SYCL range type to LLVM type.
static Optional<Type> convertRangeType(sycl::RangeType type,
LLVMTypeConverter &converter) {
return convertRangeOrIDTy(type.getDimension(), "class.cl::sycl::range",
converter);
}

//===----------------------------------------------------------------------===//
// ConstructorPattern - Converts `sycl.constructor` to LLVM.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -263,8 +279,7 @@ void mlir::sycl::populateSYCLToLLVMTypeConversion(
return convertArrayType(type, typeConverter);
});
typeConverter.addConversion([&](sycl::GroupType type) {
llvm_unreachable("SYCLToLLVM - sycl::GroupType not handle (yet)");
return llvm::None;
return convertGroupType(type, typeConverter);
});
typeConverter.addConversion(
[&](sycl::IDType type) { return convertIDType(type, typeConverter); });
Expand All @@ -275,8 +290,7 @@ void mlir::sycl::populateSYCLToLLVMTypeConversion(
return convertItemType(type, typeConverter);
});
typeConverter.addConversion([&](sycl::NdItemType type) {
llvm_unreachable("SYCLToLLVM - sycl::NdItemType not handle (yet)");
return llvm::None;
return convertNdItemType(type, typeConverter);
});
typeConverter.addConversion([&](sycl::RangeType type) {
return convertRangeType(type, typeConverter);
Expand Down
22 changes: 17 additions & 5 deletions mlir-sycl/test/Conversion/SYCLToLLVM/sycl-types-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
// CHECK: llvm.func @test_accessorImplDevice(%arg0: !llvm.[[ACCESSORIMPLDEVICE_1:struct<"class.cl::sycl::detail::AccessorImplDevice.*", \(]][[ID_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]])
// CHECK: llvm.func @test_accessor.1(%arg0: !llvm.[[ACCESSOR_1:struct<"class.cl::sycl::accessor.*", \(]][[ACCESSORIMPLDEVICE_1]][[ID_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]], struct<(ptr<i32, 1>)>)>)
// CHECK: llvm.func @test_accessor.2(%arg0: !llvm.[[ACCESSOR_2:struct<"class.cl::sycl::accessor.*", \(]][[ACCESSORIMPLDEVICE_2:struct<"class.cl::sycl::detail::AccessorImplDevice.*", \(]][[ID_2:struct<"class.cl::sycl::id.*", \(]][[ARRAY_2]][[SUFFIX]], [[RANGE_2]][[ARRAY_2]][[SUFFIX]], [[RANGE_2]][[ARRAY_2]][[SUFFIX]][[SUFFIX]], struct<(ptr<i64, 1>)>)>)
// CHECK: llvm.func @test_item_base.true(%arg0: !llvm.[[ITEM_BASE_1_TRUE:struct<"class.cl::sycl::detail::ItemBase.1.true", \(]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]])
// CHECK: llvm.func @test_item_base.false(%arg0: !llvm.[[ITEM_BASE_2_FALSE:struct<"class.cl::sycl::detail::ItemBase.2.false", \(]][[RANGE_2]][[ARRAY_2]][[SUFFIX]], [[ID_2]][[ARRAY_2]][[SUFFIX]][[SUFFIX]])
// CHECK: llvm.func @test_item(%arg0: !llvm.[[ITEM_1_TRUE:struct<"class.cl::sycl::item.1.true", \(]][[ITEM_BASE_1_TRUE]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]][[SUFFIX]])
// CHECK: llvm.func @test_item_base.true(%arg0: !llvm.[[ITEM_BASE_1_TRUE:struct<"class.cl::sycl::detail::ItemBase.*", \(]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]])
// CHECK: llvm.func @test_item_base.false(%arg0: !llvm.[[ITEM_BASE_1_FALSE:struct<"class.cl::sycl::detail::ItemBase.*", \(]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]])
// CHECK: llvm.func @test_item.true(%arg0: !llvm.[[ITEM_1_TRUE:struct<"class.cl::sycl::item.*", \(]][[ITEM_BASE_1_TRUE]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]][[SUFFIX]])
// CHECK: llvm.func @test_item.false(%arg0: !llvm.[[ITEM_1_FALSE:struct<"class.cl::sycl::item.*", \(]][[ITEM_BASE_1_FALSE]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]][[SUFFIX]])
// CHECK: llvm.func @test_group(%arg0: !llvm.[[GROUP_1:struct<"class.cl::sycl::group.*", \(]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]])
// CHECK: llvm.func @test_nd_item(%arg0: !llvm.[[ND_ITEM_1:struct<"class.cl::sycl::nd_item.*", \(]][[ITEM_1_TRUE]][[ITEM_BASE_1_TRUE]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]][[SUFFIX]], [[ITEM_1_FALSE]][[ITEM_BASE_1_FALSE]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]][[SUFFIX]], [[GROUP_1]][[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[RANGE_1]][[ARRAY_1]][[SUFFIX]], [[ID_1]][[ARRAY_1]][[SUFFIX]][[SUFFIX]][[SUFFIX]])

module {
func.func @test_array.1(%arg0: !sycl.array<[1], (memref<1xi64>)>) {
Expand Down Expand Up @@ -40,10 +43,19 @@ module {
func.func @test_item_base.true(%arg0: !sycl.item_base<[1, true], (!sycl.range<1>, !sycl.id<1>, !sycl.id<1>)>) {
return
}
func.func @test_item_base.false(%arg0: !sycl.item_base<[2, false], (!sycl.range<2>, !sycl.id<2>)>) {
func.func @test_item_base.false(%arg0: !sycl.item_base<[1, false], (!sycl.range<1>, !sycl.id<1>)>) {
return
}
func.func @test_item(%arg0: !sycl.item<[1, true], (!sycl.item_base<[1, true], (!sycl.range<1>, !sycl.id<1>, !sycl.id<1>)>)>) {
func.func @test_item.true(%arg0: !sycl.item<[1, true], (!sycl.item_base<[1, true], (!sycl.range<1>, !sycl.id<1>, !sycl.id<1>)>)>) {
return
}
func.func @test_item.false(%arg0: !sycl.item<[1, false], (!sycl.item_base<[1, false], (!sycl.range<1>, !sycl.id<1>, !sycl.id<1>)>)>) {
return
}
func.func @test_group(%arg0: !sycl.group<[1], (!sycl.range<1>, !sycl.range<1>, !sycl.range<1>, !sycl.id<1>)>) {
return
}
func.func @test_nd_item(%arg0: !sycl.nd_item<[1], (!sycl.item<[1, true], (!sycl.item_base<[1, true], (!sycl.range<1>, !sycl.id<1>, !sycl.id<1>)>)>, !sycl.item<[1, false], (!sycl.item_base<[1, false], (!sycl.range<1>, !sycl.id<1>)>)>, !sycl.group<[1], (!sycl.range<1>, !sycl.range<1>, !sycl.range<1>, !sycl.id<1>)>)>) {
return
}
}

0 comments on commit 9656ebf

Please sign in to comment.