Skip to content

Commit

Permalink
[SymbolRefAttr] Revise SymbolRefAttr to hold a StringAttr.
Browse files Browse the repository at this point in the history
SymbolRefAttr is fundamentally a base string plus a sequence
of nested references.  Instead of storing the string data as
a copies StringRef, store it as an already-uniqued StringAttr.

This makes a lot of things simpler and more efficient because:
1) references to the symbol are already stored as StringAttr's:
   there is no need to copy the string data into MLIRContext
   multiple times.
2) This allows pointer comparisons instead of string
   comparisons (or redundant uniquing) within SymbolTable.cpp.
3) This allows SymbolTable to hold a DenseMap instead of a
   StringMap (which again copies the string data and slows
   lookup).

This is a moderately invasive patch, so I kept a lot of
compatibility APIs around.  It would be nice to explore changing
getName() to return a StringAttr for example (right now you have
to use getNameAttr()), and eliminate things like the StringRef
version of getSymbol.

Differential Revision: https://reviews.llvm.org/D108899
  • Loading branch information
lattner committed Aug 30, 2021
1 parent 3383ec5 commit 41d4aa7
Show file tree
Hide file tree
Showing 28 changed files with 288 additions and 188 deletions.
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/GPU/GPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -378,10 +378,10 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func",
unsigned getNumKernelOperands();

/// The name of the kernel's containing module.
StringRef getKernelModuleName();
StringAttr getKernelModuleName();

/// The name of the kernel.
StringRef getKernelName();
StringAttr getKernelName();

/// The i-th operand passed to the kernel function.
Value getKernelOperand(unsigned i);
Expand Down
11 changes: 9 additions & 2 deletions mlir/include/mlir/IR/Builders.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,16 @@ class Builder {
StringAttr getStringAttr(const Twine &bytes);
ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
FlatSymbolRefAttr getSymbolRefAttr(Operation *value);
FlatSymbolRefAttr getSymbolRefAttr(StringRef value);
SymbolRefAttr getSymbolRefAttr(StringRef value,
FlatSymbolRefAttr getSymbolRefAttr(StringAttr value);
SymbolRefAttr getSymbolRefAttr(StringAttr value,
ArrayRef<FlatSymbolRefAttr> nestedReferences);
SymbolRefAttr getSymbolRefAttr(StringRef value,
ArrayRef<FlatSymbolRefAttr> nestedReferences) {
return getSymbolRefAttr(getStringAttr(value), nestedReferences);
}
FlatSymbolRefAttr getSymbolRefAttr(StringRef value) {
return getSymbolRefAttr(getStringAttr(value));
}

// Returns a 0-valued attribute of the given `type`. This function only
// supports boolean, integer, and 16-/32-/64-bit float types, and vector or
Expand Down
54 changes: 38 additions & 16 deletions mlir/include/mlir/IR/BuiltinAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,21 @@ class ShapedType;
//===----------------------------------------------------------------------===//

namespace detail {
template <typename T> class ElementsAttrIterator;
template <typename T> class ElementsAttrRange;
template <typename T>
class ElementsAttrIterator;
template <typename T>
class ElementsAttrRange;
} // namespace detail

/// A base attribute that represents a reference to a static shaped tensor or
/// vector constant.
class ElementsAttr : public Attribute {
public:
using Attribute::Attribute;
template <typename T> using iterator = detail::ElementsAttrIterator<T>;
template <typename T> using iterator_range = detail::ElementsAttrRange<T>;
template <typename T>
using iterator = detail::ElementsAttrIterator<T>;
template <typename T>
using iterator_range = detail::ElementsAttrRange<T>;

/// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
/// with static shape.
Expand All @@ -52,14 +56,16 @@ class ElementsAttr : public Attribute {

/// Return the value of type 'T' at the given index, where 'T' corresponds to
/// an Attribute type.
template <typename T> T getValue(ArrayRef<uint64_t> index) const {
template <typename T>
T getValue(ArrayRef<uint64_t> index) const {
return getValue(index).template cast<T>();
}

/// Return the elements of this attribute as a value of type 'T'. Note:
/// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support
/// iteration.
template <typename T> iterator_range<T> getValues() const;
template <typename T>
iterator_range<T> getValues() const;

/// Return if the given 'index' refers to a valid element in this attribute.
bool isValidIndex(ArrayRef<uint64_t> index) const;
Expand Down Expand Up @@ -139,7 +145,8 @@ class DenseElementIndexedIteratorImpl
};

/// Type trait detector that checks if a given type T is a complex type.
template <typename T> struct is_complex_t : public std::false_type {};
template <typename T>
struct is_complex_t : public std::false_type {};
template <typename T>
struct is_complex_t<std::complex<T>> : public std::true_type {};
} // namespace detail
Expand All @@ -154,7 +161,8 @@ class DenseElementsAttr : public ElementsAttr {
/// floating point type that can be used to access the underlying element
/// types of a DenseElementsAttr.
// TODO: Use std::disjunction when C++17 is supported.
template <typename T> struct is_valid_cpp_fp_type {
template <typename T>
struct is_valid_cpp_fp_type {
/// The type is a valid floating point type if it is a builtin floating
/// point type, or is a potentially user defined floating point type. The
/// latter allows for supporting users that have custom types defined for
Expand Down Expand Up @@ -423,7 +431,8 @@ class DenseElementsAttr : public ElementsAttr {
Attribute getValue(ArrayRef<uint64_t> index) const {
return getValue<Attribute>(index);
}
template <typename T> T getValue(ArrayRef<uint64_t> index) const {
template <typename T>
T getValue(ArrayRef<uint64_t> index) const {
// Skip to the element corresponding to the flattened index.
return *std::next(getValues<T>().begin(), getFlattenedIndex(index));
}
Expand Down Expand Up @@ -680,8 +689,15 @@ class FlatSymbolRefAttr : public SymbolRefAttr {
return SymbolRefAttr::get(ctx, value);
}

static FlatSymbolRefAttr get(StringAttr value) {
return SymbolRefAttr::get(value);
}

/// Returns the name of the held symbol reference as a StringAttr.
StringAttr getAttr() const { return getRootReference(); }

/// Returns the name of the held symbol reference.
StringRef getValue() const { return getRootReference(); }
StringRef getValue() const { return getAttr().getValue(); }

/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Attribute attr) {
Expand Down Expand Up @@ -845,22 +861,28 @@ class ElementsAttrIterator
}

/// Utility functors used to generically implement the iterators methods.
template <typename ItT> struct PlusAssign {
template <typename ItT>
struct PlusAssign {
void operator()(ItT &it, ptrdiff_t offset) { it += offset; }
};
template <typename ItT> struct Minus {
template <typename ItT>
struct Minus {
ptrdiff_t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; }
};
template <typename ItT> struct MinusAssign {
template <typename ItT>
struct MinusAssign {
void operator()(ItT &it, ptrdiff_t offset) { it -= offset; }
};
template <typename ItT> struct Dereference {
template <typename ItT>
struct Dereference {
T operator()(ItT &it) { return *it; }
};
template <typename ItT> struct ConstructIter {
template <typename ItT>
struct ConstructIter {
void operator()(ItT &dest, const ItT &it) { ::new (&dest) ItT(it); }
};
template <typename ItT> struct DestructIter {
template <typename ItT>
struct DestructIter {
void operator()(ItT &it) { it.~ItT(); }
};

Expand Down
19 changes: 14 additions & 5 deletions mlir/include/mlir/IR/BuiltinAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -881,17 +881,26 @@ def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> {
@parent_reference::@nested_reference
```
}];
let parameters = (ins
StringRefParameter<"">:$rootReference,
ArrayRefParameter<"FlatSymbolRefAttr", "">:$nestedReferences
);
let parameters =
(ins "StringAttr":$rootReference,
ArrayRefParameter<"FlatSymbolRefAttr", "">:$nestedReferences);

let builders = [
AttrBuilderWithInferredContext<
(ins "StringAttr":$rootReference,
"ArrayRef<FlatSymbolRefAttr>":$nestedReferences), [{
return $_get(rootReference.getContext(), rootReference, nestedReferences);
}]>,
];
let extraClassDeclaration = [{
static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value);
static FlatSymbolRefAttr get(StringAttr value);

/// Returns the name of the fully resolved symbol, i.e. the leaf of the
/// reference path.
StringRef getLeafReference() const;
StringAttr getLeafReference() const;
}];
let skipDefaultBuilders = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/OpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -1734,7 +1734,7 @@ def IsNullAttr : AttrConstraint<
class ReferToOp<string opClass> : AttrConstraint<
CPred<"isa_and_nonnull<" # opClass # ">("
"::mlir::SymbolTable::lookupNearestSymbolFrom("
"&$_op, $_self.cast<::mlir::FlatSymbolRefAttr>().getValue()))">,
"&$_op, $_self.cast<::mlir::FlatSymbolRefAttr>().getAttr()))">,
"referencing to a '" # opClass # "' symbol">;

//===----------------------------------------------------------------------===//
Expand Down
29 changes: 24 additions & 5 deletions mlir/include/mlir/IR/SymbolInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {

let methods = [
InterfaceMethod<"Returns the name of this symbol.",
"StringRef", "getName", (ins), [{
"StringAttr", "getNameAttr", (ins), [{
// Don't rely on the trait implementation as optional symbol operations
// may override this.
return mlir::SymbolTable::getSymbolName($_op);
Expand All @@ -40,11 +40,10 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
}]
>,
InterfaceMethod<"Sets the name of this symbol.",
"void", "setName", (ins "StringRef":$name), [{}],
"void", "setName", (ins "StringAttr":$name), [{}],
/*defaultImplementation=*/[{
this->getOperation()->setAttr(
mlir::SymbolTable::getSymbolAttrName(),
StringAttr::get(this->getOperation()->getContext(), name));
mlir::SymbolTable::getSymbolAttrName(), name);
}]
>,
InterfaceMethod<"Gets the visibility of this symbol.",
Expand Down Expand Up @@ -122,7 +121,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
symbol 'newSymbol' that are nested within the given operation 'from'.
Note: See mlir::SymbolTable::replaceAllSymbolUses for more details.
}],
"LogicalResult", "replaceAllSymbolUses", (ins "StringRef":$newSymbol,
"LogicalResult", "replaceAllSymbolUses", (ins "StringAttr":$newSymbol,
"Operation *":$from), [{}],
/*defaultImplementation=*/[{
return ::mlir::SymbolTable::replaceAllSymbolUses(this->getOperation(),
Expand Down Expand Up @@ -176,6 +175,16 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
}];

let extraClassDeclaration = [{
/// Convenience version of `getNameAttr` that returns a StringRef.
StringRef getName() {
return getNameAttr().getValue();
}

/// Convenience version of `setName` that take a StringRef.
void setName(StringRef name) {
setName(StringAttr::get(this->getContext(), name));
}

/// Custom classof that handles the case where the symbol is optional.
static bool classof(Operation *op) {
auto *opConcept = getInterfaceFor(op);
Expand All @@ -188,6 +197,16 @@ def Symbol : OpInterface<"SymbolOpInterface"> {

let extraTraitClassDeclaration = [{
using Visibility = mlir::SymbolTable::Visibility;

/// Convenience version of `getNameAttr` that returns a StringRef.
StringRef getName() {
return getNameAttr().getValue();
}

/// Convenience version of `setName` that take a StringRef.
void setName(StringRef name) {
setName(StringAttr::get(this->getContext(), name));
}
}];
}

Expand Down
Loading

0 comments on commit 41d4aa7

Please sign in to comment.