Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][LLVM] Remove typed pointers from the LLVM dialect #71285

Merged
merged 5 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,7 @@ class GEPIndicesAdaptor {
/// global and use it to compute the address of the first character in the
/// string (operations inserted at the builder insertion point).
Value createGlobalString(Location loc, OpBuilder &builder, StringRef name,
StringRef value, Linkage linkage,
bool useOpaquePointers = true);
StringRef value, Linkage linkage);

/// LLVM requires some operations to be inside of a Module operation. This
/// function confirms that the Operation has the desired properties.
Expand Down
28 changes: 14 additions & 14 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -469,16 +469,16 @@ def LLVM_ThreadlocalAddressOp : LLVM_OneResultIntrOp<"threadlocal.address", [],

def LLVM_CoroIdOp : LLVM_IntrOp<"coro.id", [], [], [], 1> {
let arguments = (ins I32:$align,
LLVM_i8Ptr:$promise,
LLVM_i8Ptr:$coroaddr,
LLVM_i8Ptr:$fnaddrs);
LLVM_AnyPointer:$promise,
LLVM_AnyPointer:$coroaddr,
LLVM_AnyPointer:$fnaddrs);
let assemblyFormat = "$align `,` $promise `,` $coroaddr `,` $fnaddrs"
" attr-dict `:` functional-type(operands, results)";
}

def LLVM_CoroBeginOp : LLVM_IntrOp<"coro.begin", [], [], [], 1> {
let arguments = (ins LLVM_TokenType:$token,
LLVM_i8Ptr:$mem);
LLVM_AnyPointer:$mem);
let assemblyFormat = "$token `,` $mem attr-dict `:` functional-type(operands, results)";
}

Expand All @@ -491,7 +491,7 @@ def LLVM_CoroAlignOp : LLVM_IntrOp<"coro.align", [0], [], [], 1> {
}

def LLVM_CoroSaveOp : LLVM_IntrOp<"coro.save", [], [], [], 1> {
let arguments = (ins LLVM_i8Ptr:$handle);
let arguments = (ins LLVM_AnyPointer:$handle);
let assemblyFormat = "$handle attr-dict `:` functional-type(operands, results)";
}

Expand All @@ -502,20 +502,20 @@ def LLVM_CoroSuspendOp : LLVM_IntrOp<"coro.suspend", [], [], [], 1> {
}

def LLVM_CoroEndOp : LLVM_IntrOp<"coro.end", [], [], [], 1> {
let arguments = (ins LLVM_i8Ptr:$handle,
let arguments = (ins LLVM_AnyPointer:$handle,
I1:$unwind,
LLVM_TokenType:$retvals);
let assemblyFormat = "$handle `,` $unwind `,` $retvals attr-dict `:` functional-type(operands, results)";
}

def LLVM_CoroFreeOp : LLVM_IntrOp<"coro.free", [], [], [], 1> {
let arguments = (ins LLVM_TokenType:$id,
LLVM_i8Ptr:$handle);
LLVM_AnyPointer:$handle);
let assemblyFormat = "$id `,` $handle attr-dict `:` functional-type(operands, results)";
}

def LLVM_CoroResumeOp : LLVM_IntrOp<"coro.resume", [], [], [], 0> {
let arguments = (ins LLVM_i8Ptr:$handle);
let arguments = (ins LLVM_AnyPointer:$handle);
let assemblyFormat = "$handle attr-dict `:` qualified(type($handle))";
}

Expand Down Expand Up @@ -591,19 +591,19 @@ def LLVM_DbgLabelOp : LLVM_IntrOp<"dbg.label", [], [], [], 0> {
//

def LLVM_VaStartOp : LLVM_ZeroResultIntrOp<"vastart">,
Arguments<(ins LLVM_i8Ptr:$arg_list)> {
Arguments<(ins LLVM_AnyPointer:$arg_list)> {
let assemblyFormat = "$arg_list attr-dict `:` qualified(type($arg_list))";
let summary = "Initializes `arg_list` for subsequent variadic argument extractions.";
}

def LLVM_VaCopyOp : LLVM_ZeroResultIntrOp<"vacopy">,
Arguments<(ins LLVM_i8Ptr:$dest_list, LLVM_i8Ptr:$src_list)> {
Arguments<(ins LLVM_AnyPointer:$dest_list, LLVM_AnyPointer:$src_list)> {
let assemblyFormat = "$src_list `to` $dest_list attr-dict `:` type(operands)";
let summary = "Copies the current argument position from `src_list` to `dest_list`.";
}

def LLVM_VaEndOp : LLVM_ZeroResultIntrOp<"vaend">,
Arguments<(ins LLVM_i8Ptr:$arg_list)> {
Arguments<(ins LLVM_AnyPointer:$arg_list)> {
let assemblyFormat = "$arg_list attr-dict `:` qualified(type($arg_list))";
let summary = "Destroys `arg_list`, which has been initialized by `intr.vastart` or `intr.vacopy`.";
}
Expand All @@ -613,7 +613,7 @@ def LLVM_VaEndOp : LLVM_ZeroResultIntrOp<"vaend">,
//

def LLVM_EhTypeidForOp : LLVM_OneResultIntrOp<"eh.typeid.for"> {
let arguments = (ins LLVM_i8Ptr:$type_info);
let arguments = (ins LLVM_AnyPointer:$type_info);
let assemblyFormat = "$type_info attr-dict `:` functional-type(operands, results)";
}

Expand Down Expand Up @@ -927,12 +927,12 @@ def LLVM_PtrAnnotation
: LLVM_OneResultIntrOp<"ptr.annotation", [0], [2],
[AllTypesMatch<["res", "ptr"]>,
AllTypesMatch<["annotation", "fileName", "attr"]>]> {
let arguments = (ins LLVM_PointerTo<AnySignlessInteger>:$ptr,
let arguments = (ins LLVM_AnyPointer:$ptr,
LLVM_AnyPointer:$annotation,
LLVM_AnyPointer:$fileName,
I32:$line,
LLVM_AnyPointer:$attr);
let results = (outs LLVM_PointerTo<AnySignlessInteger>:$res);
let results = (outs LLVM_AnyPointer:$res);
}

def LLVM_Annotation
Expand Down
35 changes: 5 additions & 30 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,43 +55,18 @@ def LLVM_AnyFloat : Type<
def LLVM_AnyPointer : Type<CPred<"::llvm::isa<::mlir::LLVM::LLVMPointerType>($_self)">,
"LLVM pointer type", "::mlir::LLVM::LLVMPointerType">;

def LLVM_OpaquePointer : Type<
And<[LLVM_AnyPointer.predicate,
CPred<"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self).isOpaque()">]>,
"LLVM opaque pointer", "::mlir::LLVM::LLVMPointerType">;

// Type constraint accepting LLVM pointer type with an additional constraint
// on the element type.
class LLVM_PointerTo<Type pointee> : Type<
And<[LLVM_AnyPointer.predicate,
Or<[LLVM_OpaquePointer.predicate,
SubstLeaves<
"$_self",
"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self).getElementType()",
pointee.predicate>]>]>,
"LLVM pointer to " # pointee.summary, "::mlir::LLVM::LLVMPointerType">;

// Opaque pointer in a given address space.
class LLVM_OpaquePointerInAddressSpace<int addressSpace> : Type<
And<[LLVM_OpaquePointer.predicate,
CPred<
"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self).getAddressSpace() == "
# addressSpace>]>,
class LLVM_PointerInAddressSpace<int addressSpace> : Type<
And<[LLVM_AnyPointer.predicate,
CPred<
"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self).getAddressSpace() == "
# addressSpace>]>,
"Opaque LLVM pointer in address space " # addressSpace,
"::mlir::LLVM::LLVMPointerType"> {
let builderCall = "$_builder.getType<::mlir::LLVM::LLVMPointerType>("
# addressSpace # ")";
}

// Type constraints accepting LLVM pointer type to integer of a specific width.
class LLVM_IntPtrBase<int width, int addressSpace = 0> : Type<
And<[LLVM_PointerTo<I<width>>.predicate,
CPred<"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self).getAddressSpace()"
" == " # addressSpace>]>,
"LLVM pointer to " # I<width>.summary>;

def LLVM_i8Ptr : LLVM_IntPtrBase<8>;

// Type constraint accepting any LLVM structure type.
def LLVM_AnyStruct : Type<CPred<"::llvm::isa<::mlir::LLVM::LLVMStructType>($_self)">,
"LLVM structure type">;
Expand Down
56 changes: 14 additions & 42 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,14 @@ def LLVM_AllocaOp : LLVM_Op<"alloca",
LLVM_MemOpPatterns {
let arguments = (ins AnyInteger:$arraySize,
OptionalAttr<I64Attr>:$alignment,
OptionalAttr<TypeAttr>:$elem_type,
TypeAttr:$elem_type,
UnitAttr:$inalloca);
let results = (outs Res<LLVM_AnyPointer, "",
[MemAlloc<AutomaticAllocationScopeResource>]>:$res);
string llvmInstName = "Alloca";
string llvmBuilder = [{
auto addrSpace = $_resultType->getPointerAddressSpace();
llvm::Type *elementType = moduleTranslation.convertType(
$elem_type ? *$elem_type
: ::llvm::cast<LLVMPointerType>(op.getType()).getElementType());
llvm::Type *elementType = moduleTranslation.convertType($elem_type);
auto *inst = builder.CreateAlloca(elementType, addrSpace, $arraySize);
}] # setAlignmentCode # [{
inst->setUsedWithInAlloca($inalloca);
Expand All @@ -207,31 +205,16 @@ def LLVM_AllocaOp : LLVM_Op<"alloca",
$res = $_builder.create<LLVM::AllocaOp>(
$_location, $_resultType, $arraySize,
alignment == 0 ? IntegerAttr() : $_builder.getI64IntegerAttr(alignment),
TypeAttr::get(allocatedType), allocaInst->isUsedWithInAlloca());
allocatedType, allocaInst->isUsedWithInAlloca());
}];
let builders = [
DeprecatedOpBuilder<"the usage of typed pointers is deprecated",
(ins "Type":$resultType, "Value":$arraySize,
"unsigned":$alignment),
[{
assert(!::llvm::cast<LLVMPointerType>(resultType).isOpaque() &&
"pass the allocated type explicitly if opaque pointers are used");
if (alignment == 0)
return build($_builder, $_state, resultType, arraySize, IntegerAttr(),
TypeAttr(), false);
build($_builder, $_state, resultType, arraySize,
$_builder.getI64IntegerAttr(alignment), TypeAttr(), false);
}]>,
OpBuilder<(ins "Type":$resultType, "Type":$elementType, "Value":$arraySize,
CArg<"unsigned", "0">:$alignment),
[{
TypeAttr elemTypeAttr =
::llvm::cast<LLVMPointerType>(resultType).isOpaque() ?
TypeAttr::get(elementType) : TypeAttr();
build($_builder, $_state, resultType, arraySize,
alignment == 0 ? IntegerAttr()
: $_builder.getI64IntegerAttr(alignment),
elemTypeAttr, false);
elementType, false);

}]>
];
Expand All @@ -247,7 +230,7 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$base,
Variadic<LLVM_ScalarOrVectorOf<AnyInteger>>:$dynamicIndices,
DenseI32ArrayAttr:$rawConstantIndices,
OptionalAttr<TypeAttr>:$elem_type,
TypeAttr:$elem_type,
UnitAttr:$inbounds);
let results = (outs LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$res);
let skipDefaultBuilders = 1;
Expand Down Expand Up @@ -282,14 +265,6 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
OpBuilder<(ins "Type":$resultType, "Type":$basePtrType, "Value":$basePtr,
"ValueRange":$indices, CArg<"bool", "false">:$inbounds,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
DeprecatedOpBuilder<"the usage of typed pointers is deprecated",
(ins "Type":$resultType, "Value":$basePtr,
"ValueRange":$indices, CArg<"bool", "false">:$inbounds,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
DeprecatedOpBuilder<"the usage of typed pointers is deprecated",
(ins "Type":$resultType, "Value":$basePtr,
"ArrayRef<GEPArg>":$indices, CArg<"bool", "false">:$inbounds,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
OpBuilder<(ins "Type":$resultType, "Type":$basePtrType, "Value":$basePtr,
"ArrayRef<GEPArg>":$indices, CArg<"bool", "false">:$inbounds,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
Expand All @@ -313,7 +288,7 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
let assemblyFormat = [{
(`inbounds` $inbounds^)?
$base `[` custom<GEPIndices>($dynamicIndices, $rawConstantIndices) `]` attr-dict
`:` functional-type(operands, results) (`,` $elem_type^)?
`:` functional-type(operands, results) `,` $elem_type
}];

let extraClassDeclaration = [{
Expand All @@ -332,7 +307,7 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
dag args = (ins LLVM_PointerTo<LLVM_LoadableType>:$addr,
dag args = (ins LLVM_AnyPointer:$addr,
OptionalAttr<I64Attr>:$alignment,
UnitAttr:$volatile_,
UnitAttr:$nontemporal,
Expand Down Expand Up @@ -370,7 +345,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
let assemblyFormat = [{
(`volatile` $volatile_^)? $addr
(`atomic` (`syncscope` `(` $syncscope^ `)`)? $ordering^)?
attr-dict `:` custom<LoadType>(type($addr), type($res))
attr-dict `:` qualified(type($addr)) `->` type($res)

}];
string llvmBuilder = [{
auto *inst = builder.CreateLoad($_resultType, $addr, $volatile_);
Expand All @@ -391,9 +367,6 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
getLLVMSyncScope(loadInst));
}];
let builders = [
DeprecatedOpBuilder<"the usage of typed pointers is deprecated",
(ins "Value":$addr, CArg<"unsigned", "0">:$alignment,
CArg<"bool", "false">:$isVolatile, CArg<"bool", "false">:$isNonTemporal)>,
OpBuilder<(ins "Type":$type, "Value":$addr,
CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile,
CArg<"bool", "false">:$isNonTemporal,
Expand All @@ -408,7 +381,7 @@ def LLVM_StoreOp : LLVM_MemAccessOpBase<"store",
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
dag args = (ins LLVM_LoadableType:$value,
LLVM_PointerTo<LLVM_LoadableType>:$addr,
LLVM_AnyPointer:$addr,
OptionalAttr<I64Attr>:$alignment,
UnitAttr:$volatile_,
UnitAttr:$nontemporal,
Expand Down Expand Up @@ -445,7 +418,7 @@ def LLVM_StoreOp : LLVM_MemAccessOpBase<"store",
let assemblyFormat = [{
(`volatile` $volatile_^)? $value `,` $addr
(`atomic` (`syncscope` `(` $syncscope^ `)`)? $ordering^)?
attr-dict `:` custom<StoreType>(type($value), type($addr))
attr-dict `:` type($value) `,` qualified(type($addr))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need qualified here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without qualified, this just prints <1>, instead of !llvm.ptr<1>. Tbh, I do not fully understand the logic behind this, but I've observed similar issues in other contexts.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ftynse are you fine with this change or should we get to the bottom of this issue?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to understand better why and when is this necessary (the feature looks poorly documented), but this should not block this from landing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This behaviour was introduced by ee09087. As far as I can tell it essentially always omits the dialect prefix if the attribute/type class being parsed is known. Since in this Op it is known to be a LLVMPointerType it'll elide the !llvm.ptr by default when printing.
63f0c00 then apparently introduced the qualified directive to be able to restore the behaviour to before where the dialect prefix is always printed.

Personally I think this is non-intuative and that the default should be printing the full type by default instead, but that is another discussion entirely. Using qualified seems to be the intended way to fully qualify the type in the assembly format.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is WAI: qualified is a poor default IMO.
For example why would !llvm be a better choice here? And even if it is, why shouldn't we express the choice?
Should we even print the type when the address space is the default one? Lot of possible simplifications...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that not only !llvm is omitted, but the ptr prefix as well. It therefore prints just <1> when using an address space rather than either ptr<1> or !llvm.ptr<1>. This IMO makes it non-obvious what the type is, nor what the meaning of the <1> is supposed to be.

Copy link
Collaborator

@joker-eph joker-eph Dec 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course! But the answer is contextual...
For example it's not obvious to me that qualified is just the best answer here, what about the following for example?

let assemblyFormat = ` ....    attr-dict `:` type($value) `,` `ptr` `` type($addr)`

I see the design of the assembly format as... a design! It needs to be intentional, and from this point of view we need the most basic blocks to construct a nice format.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main reason I added qualified in this revision is to ensure as few breaking changes as possible, considering that the pointer type changes already forced me to change hundreds of lines.

I agree that qualified might not be the best idea here, as the "llvm." part brings no benefit. The format you propose would in my opinion be a much better fit. The question remains if we should always print the address space or not, and if we should even print the pointer type if the address space is zero.

Note that we should also consider all other operations where I used qualified in this discussion, i.e., LLVM_StoreOp, LLVM_AtomicRMWOp, LLVM_AtomicCmpXchgOp, NVVM_MBarrierArriveSharedOp.

Copy link
Collaborator

@joker-eph joker-eph Dec 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be clear I wasn't criticizing this PR: I was providing a rational for "building new dialects".

}];
string llvmBuilder = [{
auto *inst = builder.CreateStore($value, $addr, $volatile_);
Expand Down Expand Up @@ -651,8 +624,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
OpBuilder<(ins "LLVMFunctionType":$calleeType, "FlatSymbolRefAttr":$callee,
CArg<"ValueRange", "{}">:$args)>,
OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
CArg<"ValueRange", "{}">:$args)>,
OpBuilder<(ins "Value":$callee, "ValueRange":$args)>
CArg<"ValueRange", "{}">:$args)>
];
let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
Expand Down Expand Up @@ -1636,7 +1608,7 @@ def LLVM_AtomicRMWOp : LLVM_MemAccessOpBase<"atomicrmw", [
TypesMatchWith<"result #0 and operand #1 have the same type",
"val", "res", "$_self">]> {
dag args = (ins AtomicBinOp:$bin_op,
LLVM_PointerTo<LLVM_AtomicRMWType>:$ptr,
LLVM_AnyPointer:$ptr,
LLVM_AtomicRMWType:$val, AtomicOrdering:$ordering,
OptionalAttr<StrAttr>:$syncscope,
OptionalAttr<I64Attr>:$alignment,
Expand Down Expand Up @@ -1687,7 +1659,7 @@ def LLVM_AtomicCmpXchgOp : LLVM_MemAccessOpBase<"cmpxchg", [
TypesMatchWith<"result #0 has an LLVM struct type consisting of "
"the type of operand #2 and a bool", "val", "res",
"getValAndBoolStructType($_self)">]> {
dag args = (ins LLVM_PointerTo<LLVM_AtomicCmpXchgType>:$ptr,
dag args = (ins LLVM_AnyPointer:$ptr,
LLVM_AtomicCmpXchgType:$cmp, LLVM_AtomicCmpXchgType:$val,
AtomicOrdering:$success_ordering,
AtomicOrdering:$failure_ordering,
Expand Down
21 changes: 4 additions & 17 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -137,30 +137,17 @@ def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
```
}];

let parameters = (ins DefaultValuedParameter<"Type", "Type()">:$elementType,
DefaultValuedParameter<"unsigned", "0">:$addressSpace);
let parameters = (ins DefaultValuedParameter<"unsigned", "0">:$addressSpace);
let assemblyFormat = [{
(`<` custom<Pointer>($elementType, $addressSpace)^ `>`)?
(`<` $addressSpace^ `>`)?
}];

let genVerifyDecl = 1;

let skipDefaultBuilders = 1;
let builders = [
TypeBuilderWithInferredContext<(ins "Type":$elementType,
CArg<"unsigned", "0">:$addressSpace)>,
TypeBuilder<(ins CArg<"unsigned", "0">:$addressSpace), [{
return $_get($_ctxt, Type(), addressSpace);
return $_get($_ctxt, addressSpace);
}]>
];

let extraClassDeclaration = [{
/// Returns `true` if this type is the opaque pointer type, i.e., it has no
/// pointed-to type.
bool isOpaque() const { return !getElementType(); }

/// Checks if the given type can have a pointer type pointing to it.
static bool isValidElementType(Type type);
}];
}

//===----------------------------------------------------------------------===//
Expand Down
Loading