Skip to content

Commit

Permalink
Preserve reflection attrs on functions when wrapping for the native A…
Browse files Browse the repository at this point in the history
…BI. (#16129)

(required for #16130)
  • Loading branch information
benvanik authored Jan 17, 2024
1 parent 91803de commit 6ab1ed8
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,12 @@ static void populateReflectionAttrs(IREE::ABI::InvocationModel invocationModel,
auto *context = exportOp.getContext();
SmallVector<NamedAttribute> attrs;

if (auto reflectionAttr =
exportOp->getAttrOfType<DictionaryAttr>("iree.reflection")) {
attrs.append(reflectionAttr.getValue().begin(),
reflectionAttr.getValue().end());
}

if (auto abiAttr = exportOp->getAttr("iree.abi")) {
attrs.emplace_back(StringAttr::get(context, "iree.abi"), abiAttr);
}
Expand Down Expand Up @@ -487,6 +493,7 @@ createExportWrapperFunc(IREE::ABI::InvocationModel invocationModel,

// Populate the reflection attrs based on the original types.
populateReflectionAttrs(invocationModel, exportOp, wrapperOp);
exportOp->removeAttr("iree.reflection");

auto *entryBlock = wrapperOp.addEntryBlock();
auto entryBuilder = OpBuilder::atBlockBegin(entryBlock);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,25 @@ func.func @dynamicEntry(%arg0: tensor<?x8x8x3xf32>, %arg1: tensor<?x8x8x3xf32>)

// -----

// Tests that an existing iree.reflection dictionary is merged with the new
// reflection information.

// CHECK-LABEL: func.func @existingReflection
// CHECK-SAME: iree.reflection =
// CHECK-SAME: iree.abi.declaration = "sync func @existingReflection
// CHECK-SAME: some.attr = 4 : index
// CHECK: func.func private @_existingReflection
// CHECK-NOT: iree.reflection = {some.attr = 4 : index}
func.func @existingReflection() attributes {
iree.reflection = {
some.attr = 4 : index
}
} {
return
}

// -----

// Tests that iree.abi.declaration is added when needed and otherwise the user
// provided value is passed through.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ VMConversionTarget::nestModuleForConversion(mlir::ModuleOp outerModuleOp) {
if (!innerModuleOp) {
innerModuleOp =
ModuleOp::create(outerModuleOp.getLoc(), outerModuleOp.getName());
if (auto reflectionAttr =
outerModuleOp->getAttrOfType<DictionaryAttr>("iree.reflection")) {
innerModuleOp->setAttr("iree.reflection", reflectionAttr);
outerModuleOp->removeAttr("iree.reflection");
}
innerModuleOp.getBodyRegion().takeBody(outerModuleOp.getBodyRegion());
outerModuleOp.getBodyRegion().getBlocks().push_back(new Block());
outerModuleOp.push_back(innerModuleOp);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ class ModuleOpConversion : public OpConversionPattern<ModuleOp> {
if (auto version = srcOp->getAttrOfType<IntegerAttr>("vm.version")) {
newModuleOp.setVersionAttr(version);
}
if (auto reflectionAttr =
srcOp->getAttrOfType<DictionaryAttr>("iree.reflection")) {
newModuleOp->setAttr("iree.reflection", reflectionAttr);
}
Block *firstCreatedBlock = &newModuleOp.getBodyRegion().front();
rewriter.inlineRegionBefore(srcOp.getBodyRegion(), firstCreatedBlock);
auto blockRange = llvm::make_range(Region::iterator(firstCreatedBlock),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,27 @@ canonicalizeModule(IREE::VM::BytecodeTargetOptions bytecodeOptions,
return success();
}

// Returns a list of reflection AttrDefs with entries from |attrs| (or an
// empty/null list).
static iree_vm_AttrDef_vec_ref_t makeAttrDefs(DictionaryAttr attrs,
FlatbufferBuilder &fbb) {
if (!attrs || attrs.empty())
return 0;
SmallVector<iree_vm_AttrDef_ref_t> attrRefs;
for (auto attr : attrs) {
auto key = attr.getName().strref();
auto value = llvm::dyn_cast<StringAttr>(attr.getValue());
if (!value || key.empty())
continue;
// NOTE: if we actually want to keep these we should dedupe them (as the
// keys and likely several of the values are shared across all functions).
auto valueRef = fbb.createString(value.getValue());
auto keyRef = fbb.createString(key);
attrRefs.push_back(iree_vm_AttrDef_create(fbb, keyRef, valueRef));
}
return iree_vm_AttrDef_vec_create(fbb, attrRefs.data(), attrRefs.size());
}

// Creates a FunctionSignatureDef based on the given function metadata.
// Some fields are not used on all signature defs and added only when present on
// the argument objects/attrs.
Expand Down Expand Up @@ -236,24 +257,9 @@ makeFunctionSignatureDef(IREE::VM::FuncOp funcOp,
if (!cconv.has_value())
return {};

// Reflection attributes.
iree_vm_AttrDef_vec_ref_t attrsRef = 0;
if (auto attrs = funcOp->getAttrOfType<DictionaryAttr>("iree.reflection")) {
SmallVector<iree_vm_AttrDef_ref_t> attrRefs;
for (auto attr : attrs) {
auto key = attr.getName().strref();
auto value = llvm::dyn_cast<StringAttr>(attr.getValue());
if (!value || key.empty())
continue;
// NOTE: if we actually want to keep these we should dedupe them (as the
// keys and likely several of the values are shared across all functions).
auto valueRef = fbb.createString(value.getValue());
auto keyRef = fbb.createString(key);
attrRefs.push_back(iree_vm_AttrDef_create(fbb, keyRef, valueRef));
}
attrsRef =
iree_vm_AttrDef_vec_create(fbb, attrRefs.data(), attrRefs.size());
}
// Encode reflection attributes.
iree_vm_AttrDef_vec_ref_t attrsRef = makeAttrDefs(
funcOp->getAttrOfType<DictionaryAttr>("iree.reflection"), fbb);

return createFunctionSignatureDef(funcOp.getFunctionType(), typeTable,
cconv.value(), attrsRef, fbb);
Expand Down Expand Up @@ -474,6 +480,10 @@ buildFlatBufferModule(IREE::VM::TargetOptions vmOptions,
return iree_vm_TypeDef_end(fbb);
});

// Encode reflection attributes.
iree_vm_AttrDef_vec_ref_t attrsRef = makeAttrDefs(
moduleOp->getAttrOfType<DictionaryAttr>("iree.reflection"), fbb);

// NOTE: we keep the vectors clustered here so that we can hopefully keep the
// pages mapped at runtime; vector dereferences in FlatBuffers require
// touching these structs to get length/etc and as such we don't want to be
Expand Down Expand Up @@ -525,7 +535,7 @@ buildFlatBufferModule(IREE::VM::TargetOptions vmOptions,
iree_vm_BytecodeModuleDef_version_add(fbb,
moduleOp.getVersion().value_or(0u));
iree_vm_BytecodeModuleDef_requirements_add(fbb, moduleRequirements);
// TODO(benvanik): iree_vm_BytecodeModuleDef_attrs_add
iree_vm_BytecodeModuleDef_attrs_add(fbb, attrsRef);
iree_vm_BytecodeModuleDef_types_add(fbb, typesRef);
iree_vm_BytecodeModuleDef_dependencies_add(fbb, dependenciesRef);
iree_vm_BytecodeModuleDef_imported_functions_add(fbb, importFuncsRef);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ iree_lit_test_suite(
[
"constant_encoding.mlir",
"dependencies.mlir",
"function_attrs.mlir",
"module_encoding_smoke.mlir",
"reflection_attrs.mlir",
],
include = ["*.mlir"],
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ iree_lit_test_suite(
SRCS
"constant_encoding.mlir"
"dependencies.mlir"
"function_attrs.mlir"
"module_encoding_smoke.mlir"
"reflection_attrs.mlir"
TOOLS
FileCheck
iree-compile
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: iree-compile --split-input-file --compile-mode=vm \
// RUN: --iree-vm-bytecode-module-output-format=flatbuffer-text %s | FileCheck %s

// CHECK-LABEL: simple_module
// CHECK: "attrs":
// CHECK: "key": "module_attr_0"
// CHECK: "value": "MODULE_ATTR_0"
// CHECK: "key": "module_attr_1"
// CHECK: "value": "MODULE_ATTR_1"
vm.module @simple_module attributes {
iree.reflection = {
module_attr_0 = "MODULE_ATTR_0",
module_attr_1 = "MODULE_ATTR_1"
}
} {
vm.export @func
// CHECK: "exported_functions":
// CHECK: "attrs":
// CHECK: "key": "func_attr_0"
// CHECK: "value": "FUNC_ATTR_0"
// CHECK: "key": "func_attr_1"
// CHECK: "value": "FUNC_ATTR_1"
vm.func @func(%arg0 : i32) -> i32 attributes {
iree.reflection = {
func_attr_0 = "FUNC_ATTR_0",
func_attr_1 = "FUNC_ATTR_1"
}
} {
vm.return %arg0 : i32
}
}
2 changes: 1 addition & 1 deletion tools/iree-dump-module-main.c
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ static iree_status_t iree_tooling_dump_module_metadata(
fprintf(stdout, "Attributes:\n");
iree_tooling_print_attr_defs(iree_vm_BytecodeModuleDef_attrs(module_def),
2);
fprintf(stdout, "\n\n");
fprintf(stdout, "\n");
}

if (iree_vm_BytecodeModuleDef_types_is_present(module_def)) {
Expand Down

0 comments on commit 6ab1ed8

Please sign in to comment.