Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][transform] Fix crash when op is erased during transform.foreach #66357

Merged

Conversation

matthias-springer
Copy link
Member

Fixes a crash when an op, that is mapped to handle that a transform.foreach iterates over, was erased (through the TrackingRewriter). Erasing an op removes it from all mappings and invalidates iterators. This is already taken care of when an op is iterating over payload ops in its apply method, but not when another transform op is erasing a tracked payload op.

@llvmbot
Copy link
Member

llvmbot commented Sep 14, 2023

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Changes Fixes a crash when an op, that is mapped to handle that a `transform.foreach` iterates over, was erased (through the `TrackingRewriter`). Erasing an op removes it from all mappings and invalidates iterators. This is already taken care of when an op is iterating over payload ops in its `apply` method, but not when another transform op is erasing a tracked payload op. -- Full diff: https://github.com//pull/66357.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+5)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+14)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+16-2)
  • (modified) mlir/test/Dialect/Transform/test-interpreter.mlir (+22)
  • (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 114d79555dcef50..efd8d573936c332 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -1156,6 +1156,11 @@ bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
 void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
 void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
 
+/// Checks whether the transform op modifies the payload.
+bool doesModifyPayload(transform::TransformOpInterface transform);
+/// Checks whether the transform op reads the payload.
+bool doesReadPayload(transform::TransformOpInterface transform);
+
 /// Populates `consumedArguments` with positions of `block` arguments that are
 /// consumed by the operations in the `block`.
 void getConsumedBlockArguments(
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index ed987ac4b51646c..00450a1ff8f36cf 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1904,6 +1904,20 @@ void transform::onlyReadsPayload(
   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
 }
 
+bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
+  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  iface.getEffects(effects);
+  return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
+}
+
+bool transform::doesReadPayload(transform::TransformOpInterface transform) {
+  auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  iface.getEffects(effects);
+  return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
+}
+
 void transform::getConsumedBlockArguments(
     Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
   SmallVector<MemoryEffects::EffectInstance> effects;
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7bbbbba4134b184..0c9a80699aed09b 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1121,8 +1121,11 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
   SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
-
-  for (Operation *op : state.getPayloadOps(getTarget())) {
+  // Store payload ops in a vector because ops may be removed from the mapping
+  // by the TrackingRewriter while the iteration is in progress.
+  auto it = state.getPayloadOps(getTarget());
+  SmallVector<Operation *> targets(it.begin(), it.end());
+  for (Operation *op : targets) {
     auto scope = state.make_region_scope(getBody());
     if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
       return DiagnosedSilenceableFailure::definiteFailure();
@@ -1152,6 +1155,7 @@ void transform::ForeachOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   BlockArgument iterVar = getIterationVariable();
   if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+
         return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
       })) {
     consumesHandle(getTarget(), effects);
@@ -1159,6 +1163,16 @@ void transform::ForeachOp::getEffects(
     onlyReadsHandle(getTarget(), effects);
   }
 
+  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+        return doesModifyPayload(cast<TransformOpInterface>(&op));
+      })) {
+    modifiesPayload(effects);
+  } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+               return doesReadPayload(cast<TransformOpInterface>(&op));
+             })) {
+    onlyReadsPayload(effects);
+  }
+
   for (Value result : getResults())
     producesHandle(result, effects);
 }
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index db97d0a0887576f..68e3a4851539690 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -691,6 +691,28 @@ transform.with_pdl_patterns {
 
 // -----
 
+// CHECK-LABEL: func @consume_in_foreach()
+//  CHECK-NEXT:   return
+func.func @consume_in_foreach() {
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 1 : index
+  %2 = arith.constant 2 : index
+  %3 = arith.constant 3 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %f = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.foreach %f : !transform.any_op {
+  ^bb2(%arg2: !transform.any_op):
+    // expected-remark @below {{erasing}}
+    transform.test_emit_remark_and_erase_operand %arg2, "erasing" : !transform.any_op
+  }
+}
+
+// -----
+
 func.func @bar() {
   scf.execute_region {
     // expected-remark @below {{transform applied}}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index afd5011f17c6d2b..21f9ff5999a5ed5 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -390,7 +390,7 @@ DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   emitRemark() << getRemark();
   for (Operation *op : state.getPayloadOps(getTarget()))
-    op->erase();
+    rewriter.eraseOp(op);
 
   if (getFailAfterErase())
     return emitSilenceableError() << "silenceable error";

Fixes a crash when an op, that is mapped to handle that a `transform.foreach` iterates over, was erased (through the `TrackingRewriter`). Erasing an op removes it from all mappings and invalidates iterators. This is already taken care of when an op is iterating over payload ops in its `apply` method, but not when another transform op is erasing a tracked payload op.
@matthias-springer matthias-springer merged commit 4f63252 into llvm:main Sep 14, 2023
@ingomueller-net
Copy link
Contributor

Awesome, thanks!

kstoimenov pushed a commit to kstoimenov/llvm-project that referenced this pull request Sep 14, 2023
llvm#66357)

Fixes a crash when an op, that is mapped to handle that a
`transform.foreach` iterates over, was erased (through the
`TrackingRewriter`). Erasing an op removes it from all mappings and
invalidates iterators. This is already taken care of when an op is
iterating over payload ops in its `apply` method, but not when another
transform op is erasing a tracked payload op.
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
llvm#66357)

Fixes a crash when an op, that is mapped to handle that a
`transform.foreach` iterates over, was erased (through the
`TrackingRewriter`). Erasing an op removes it from all mappings and
invalidates iterators. This is already taken care of when an op is
iterating over payload ops in its `apply` method, but not when another
transform op is erasing a tracked payload op.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants