Skip to content

Commit

Permalink
Change the printing/parsing behavior for Attributes used in declarati…
Browse files Browse the repository at this point in the history
…ve assembly format

The new form of printing attribute in the declarative assembly is eliding the `#dialect.mnemonic` prefix to only keep the `<....>` part.

Differential Revision: https://reviews.llvm.org/D113873
  • Loading branch information
joker-eph committed Dec 8, 2021
1 parent 63cd184 commit ee09087
Show file tree
Hide file tree
Showing 32 changed files with 574 additions and 170 deletions.
15 changes: 1 addition & 14 deletions mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def ArmSVE_Dialect : Dialect {
vector operations, including a scalable vector type and intrinsics for
some Arm SVE instructions.
}];
let useDefaultTypePrinterParser = 1;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -66,20 +67,6 @@ def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
"Type":$elementType
);

let printer = [{
$_printer << "<";
for (int64_t dim : getShape())
$_printer << dim << 'x';
$_printer << getElementType() << '>';
}];

let parser = [{
VectorType vector;
if ($_parser.parseType(vector))
return Type();
return get($_ctxt, vector.getShape(), vector.getElementType());
}];

let extraClassDeclaration = [{
bool hasStaticShape() const {
return llvm::none_of(getShape(), ShapedType::isDynamic);
Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/IR/DialectImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,19 @@ struct FieldParser<
AttributeT>> {
static FailureOr<AttributeT> parse(AsmParser &parser) {
AttributeT value;
if (parser.parseAttribute(value))
if (parser.parseCustomAttributeWithFallback(value))
return failure();
return value;
}
};

/// Parse a type.
/// Parse an attribute.
template <typename TypeT>
struct FieldParser<
TypeT, std::enable_if_t<std::is_base_of<Type, TypeT>::value, TypeT>> {
static FailureOr<TypeT> parse(AsmParser &parser) {
TypeT value;
if (parser.parseType(value))
if (parser.parseCustomTypeWithFallback(value))
return failure();
return value;
}
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/IR/OpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -2984,6 +2984,9 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
: DialectType<dialect, CPred<"">, /*descr*/"", name # "Type">,
AttrOrTypeDef<"Type", name, traits, baseCppClass> {
// Make it possible to use such type as parameters for other types.
string cppType = dialect.cppNamespace # "::" # cppClassName;

// A constant builder provided when the type has no parameters.
let builderCall = !if(!empty(parameters),
"$_builder.getType<" # dialect.cppNamespace #
Expand Down
174 changes: 169 additions & 5 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,36 @@ class AsmPrinter {
virtual void printType(Type type);
virtual void printAttribute(Attribute attr);

/// Trait to check if `AttrType` provides a `print` method.
template <typename AttrOrType>
using has_print_method =
decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
template <typename AttrOrType>
using detect_has_print_method =
llvm::is_detected<has_print_method, AttrOrType>;

/// Print the provided attribute in the context of an operation custom
/// printer/parser: this will invoke directly the print method on the
/// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
template <typename AttrOrType,
std::enable_if_t<detect_has_print_method<AttrOrType>::value>
*sfinae = nullptr>
void printStrippedAttrOrType(AttrOrType attrOrType) {
if (succeeded(printAlias(attrOrType)))
return;
attrOrType.print(*this);
}

/// SFINAE for printing the provided attribute in the context of an operation
/// custom printer in the case where the attribute does not define a print
/// method.
template <typename AttrOrType,
std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
*sfinae = nullptr>
void printStrippedAttrOrType(AttrOrType attrOrType) {
*this << attrOrType;
}

/// Print the given attribute without its type. The corresponding parser must
/// provide a valid type for the attribute.
virtual void printAttributeWithoutType(Attribute attr);
Expand Down Expand Up @@ -102,6 +132,14 @@ class AsmPrinter {
AsmPrinter(const AsmPrinter &) = delete;
void operator=(const AsmPrinter &) = delete;

/// Print the alias for the given attribute, return failure if no alias could
/// be printed.
virtual LogicalResult printAlias(Attribute attr);

/// Print the alias for the given type, return failure if no alias could
/// be printed.
virtual LogicalResult printAlias(Type type);

/// The internal implementation of the printer.
Impl *impl;
};
Expand Down Expand Up @@ -608,6 +646,13 @@ class AsmParser {
/// Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;

/// Parse a custom attribute with the provided callback, unless the next
/// token is `#`, in which case the generic parser is invoked.
virtual ParseResult parseCustomAttributeWithFallback(
Attribute &result, Type type,
function_ref<ParseResult(Attribute &result, Type type)>
parseAttribute) = 0;

/// Parse an attribute of a specific kind and type.
template <typename AttrType>
ParseResult parseAttribute(AttrType &result, Type type = {}) {
Expand Down Expand Up @@ -639,9 +684,9 @@ class AsmParser {
return parseAttribute(result, Type(), attrName, attrs);
}

/// Parse an arbitrary attribute of a given type and return it in result. This
/// also adds the attribute to the specified attribute list with the specified
/// name.
/// Parse an arbitrary attribute of a given type and populate it in `result`.
/// This also adds the attribute to the specified attribute list with the
/// specified name.
template <typename AttrType>
ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
NamedAttrList &attrs) {
Expand All @@ -661,6 +706,82 @@ class AsmParser {
return success();
}

/// Trait to check if `AttrType` provides a `parse` method.
template <typename AttrType>
using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
std::declval<Type>()));
template <typename AttrType>
using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;

/// Parse a custom attribute of a given type unless the next token is `#`, in
/// which case the generic parser is invoked. The parsed attribute is
/// populated in `result` and also added to the specified attribute list with
/// the specified name.
template <typename AttrType>
std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType &result, Type type,
StringRef attrName, NamedAttrList &attrs) {
llvm::SMLoc loc = getCurrentLocation();

// Parse any kind of attribute.
Attribute attr;
if (parseCustomAttributeWithFallback(
attr, type, [&](Attribute &result, Type type) -> ParseResult {
result = AttrType::parse(*this, type);
if (!result)
return failure();
return success();
}))
return failure();

// Check for the right kind of attribute.
result = attr.dyn_cast<AttrType>();
if (!result)
return emitError(loc, "invalid kind of attribute specified");

attrs.append(attrName, result);
return success();
}

/// SFINAE parsing method for Attribute that don't implement a parse method.
template <typename AttrType>
std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType &result, Type type,
StringRef attrName, NamedAttrList &attrs) {
return parseAttribute(result, type, attrName, attrs);
}

/// Parse a custom attribute of a given type unless the next token is `#`, in
/// which case the generic parser is invoked. The parsed attribute is
/// populated in `result`.
template <typename AttrType>
std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType &result) {
llvm::SMLoc loc = getCurrentLocation();

// Parse any kind of attribute.
Attribute attr;
if (parseCustomAttributeWithFallback(
attr, {}, [&](Attribute &result, Type type) -> ParseResult {
result = AttrType::parse(*this, type);
return success(!!result);
}))
return failure();

// Check for the right kind of attribute.
result = attr.dyn_cast<AttrType>();
if (!result)
return emitError(loc, "invalid kind of attribute specified");
return success();
}

/// SFINAE parsing method for Attribute that don't implement a parse method.
template <typename AttrType>
std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType &result) {
return parseAttribute(result);
}

/// Parse an arbitrary optional attribute of a given type and return it in
/// result.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
Expand Down Expand Up @@ -740,6 +861,11 @@ class AsmParser {
/// Parse a type.
virtual ParseResult parseType(Type &result) = 0;

/// Parse a custom type with the provided callback, unless the next
/// token is `#`, in which case the generic parser is invoked.
virtual ParseResult parseCustomTypeWithFallback(
Type &result, function_ref<ParseResult(Type &result)> parseType) = 0;

/// Parse an optional type.
virtual OptionalParseResult parseOptionalType(Type &result) = 0;

Expand All @@ -753,14 +879,52 @@ class AsmParser {
if (parseType(type))
return failure();

// Check for the right kind of attribute.
// Check for the right kind of type.
result = type.dyn_cast<TypeT>();
if (!result)
return emitError(loc, "invalid kind of type specified");

return success();
}

/// Trait to check if `TypeT` provides a `parse` method.
template <typename TypeT>
using type_has_parse_method =
decltype(TypeT::parse(std::declval<AsmParser &>()));
template <typename TypeT>
using detect_type_has_parse_method =
llvm::is_detected<type_has_parse_method, TypeT>;

/// Parse a custom Type of a given type unless the next token is `#`, in
/// which case the generic parser is invoked. The parsed Type is
/// populated in `result`.
template <typename TypeT>
std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
parseCustomTypeWithFallback(TypeT &result) {
llvm::SMLoc loc = getCurrentLocation();

// Parse any kind of Type.
Type type;
if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult {
result = TypeT::parse(*this);
return success(!!result);
}))
return failure();

// Check for the right kind of Type.
result = type.dyn_cast<TypeT>();
if (!result)
return emitError(loc, "invalid kind of Type specified");
return success();
}

/// SFINAE parsing method for Type that don't implement a parse method.
template <typename TypeT>
std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
parseCustomTypeWithFallback(TypeT &result) {
return parseType(result);
}

/// Parse a type list.
ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
do {
Expand Down Expand Up @@ -792,7 +956,7 @@ class AsmParser {
if (parseColonType(type))
return failure();

// Check for the right kind of attribute.
// Check for the right kind of type.
result = type.dyn_cast<TypeType>();
if (!result)
return emitError(loc, "invalid kind of type specified");
Expand Down
26 changes: 13 additions & 13 deletions mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,21 @@ void ArmSVEDialect::initialize() {
// ScalableVectorType
//===----------------------------------------------------------------------===//

Type ArmSVEDialect::parseType(DialectAsmParser &parser) const {
llvm::SMLoc typeLoc = parser.getCurrentLocation();
{
Type genType;
auto parseResult = generatedTypeParser(parser, "vector", genType);
if (parseResult.hasValue())
return genType;
}
parser.emitError(typeLoc, "unknown type in ArmSVE dialect");
return Type();
void ScalableVectorType::print(AsmPrinter &printer) const {
printer << "<";
for (int64_t dim : getShape())
printer << dim << 'x';
printer << getElementType() << '>';
}

void ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
if (failed(generatedTypePrinter(type, os)))
llvm_unreachable("unexpected 'arm_sve' type kind");
Type ScalableVectorType::parse(AsmParser &parser) {
SmallVector<int64_t> dims;
Type eltType;
if (parser.parseLess() ||
parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
parser.parseType(eltType) || parser.parseGreater())
return {};
return ScalableVectorType::get(eltType.getContext(), dims, eltType);
}

//===----------------------------------------------------------------------===//
Expand Down
10 changes: 6 additions & 4 deletions mlir/lib/Dialect/Vector/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ static constexpr const CombiningKind combiningKindsList[] = {
};

void CombiningKindAttr::print(AsmPrinter &printer) const {
printer << "kind<";
printer << "<";
auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
return bitEnumContains(this->getKind(), kind);
});
Expand Down Expand Up @@ -215,10 +215,12 @@ Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,

void VectorDialect::printAttribute(Attribute attr,
DialectAsmPrinter &os) const {
if (auto ck = attr.dyn_cast<CombiningKindAttr>())
if (auto ck = attr.dyn_cast<CombiningKindAttr>()) {
os << "kind";
ck.print(os);
else
llvm_unreachable("Unknown attribute type");
return;
}
llvm_unreachable("Unknown attribute type");
}

//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit ee09087

Please sign in to comment.