Skip to content

Commit

Permalink
Removing duplicate results in closure regions. (#20082)
Browse files Browse the repository at this point in the history
  • Loading branch information
benvanik authored Feb 25, 2025
1 parent 3459998 commit 03be423
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 8 deletions.
6 changes: 5 additions & 1 deletion compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,11 @@ def FLOW_DispatchTensorStoreOp : FLOW_Op<"dispatch.tensor.store", [
let hasCanonicalizer = 1;
}

def FLOW_ReturnOp : FLOW_Op<"return", [Pure, ReturnLike, Terminator]> {
def FLOW_ReturnOp : FLOW_Op<"return", [
Pure,
ReturnLike,
Terminator,
]> {
let summary = [{return from a flow.dispatch_region}];
let description = [{
Returns the given values from the region and back to the host code.
Expand Down
15 changes: 12 additions & 3 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,15 @@ static void eraseStreamRegionResults(Region &region,
auto yieldOp = dyn_cast<IREE::Stream::YieldOp>(block.getTerminator());
if (!yieldOp)
continue;
llvm::SmallVector<Value> newOperands;
// HACK: there's no good way of updating the operand and size together today
// - we should add a helper to the ClosureYieldOpInterface that checks for
// size/shape aware traits and does this automatically.
for (auto i : llvm::reverse(excludedResultIndices)) {
yieldOp.getResourceOperandsMutable().erase(i);
yieldOp.getResourceOperandSizesMutable().erase(i);
unsigned resourceIndex = i;
unsigned resourceSizeIndex =
yieldOp.getResourceOperandsMutable().size() + i;
yieldOp->eraseOperand(resourceSizeIndex);
yieldOp->eraseOperand(resourceIndex);
}
}
}
Expand Down Expand Up @@ -4308,6 +4313,10 @@ YieldOp::getMutableSuccessorOperands(RegionBranchPoint point) {
return getResourceOperandsMutable();
}

MutableOperandRange YieldOp::getClosureResultsMutable() {
return getResourceOperandsMutable();
}

} // namespace mlir::iree_compiler::IREE::Stream

//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4363,6 +4363,9 @@ def Stream_YieldOp : Stream_Op<"yield", [
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>,
Terminator,
SameVariadicOperandSize,
DeclareOpInterfaceMethods<Util_ClosureYieldOpInterface, [
"getClosureResultsMutable",
]>,
Util_SizeAwareOp,
]> {
let summary = [{yields stream values from an execution region}];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,45 @@ util.func private @ElideUnusedAsyncExecuteOp(%arg0: !stream.resource<*>, %arg1:
util.return
}


// -----

// CHECK-LABEL: @FoldAsyncExecuteDuplicateResults
// CHECK-SAME: (%[[SPLAT_A_SIZE:.+]]: index, %[[SPLAT_A_VALUE:.+]]: i32, %[[SPLAT_B_SIZE:.+]]: index, %[[SPLAT_B_VALUE:.+]]: i32)
util.func private @FoldAsyncExecuteDuplicateResults(%splat_a_size: index, %splat_a_value: i32, %splat_b_size: index, %splat_b_value: i32) -> (!stream.resource<*>, !stream.resource<*>, !stream.resource<*>, !stream.timepoint) {
// CHECK: %[[RESULTS:.+]]:2, %[[TIMEPOINT:.+]] = stream.async.execute with() -> (!stream.resource<*>{%[[SPLAT_A_SIZE]]}, !stream.resource<*>{%[[SPLAT_B_SIZE]]}) {
%results:3, %timepoint = stream.async.execute with() -> (!stream.resource<*>{%splat_a_size}, !stream.resource<*>{%splat_b_size}, !stream.resource<*>{%splat_a_size}) {
// CHECK: %[[SPLAT_A:.+]] = stream.async.splat %[[SPLAT_A_VALUE]]
%splat_a = stream.async.splat %splat_a_value : i32 -> !stream.resource<*>{%splat_a_size}
// CHECK: %[[SPLAT_B:.+]] = stream.async.splat %[[SPLAT_B_VALUE]]
%splat_b = stream.async.splat %splat_b_value : i32 -> !stream.resource<*>{%splat_b_size}
// CHECK: stream.yield %[[SPLAT_A]], %[[SPLAT_B]] : !stream.resource<*>{%[[SPLAT_A_SIZE]]}, !stream.resource<*>{%[[SPLAT_B_SIZE]]}
stream.yield %splat_a, %splat_b, %splat_a : !stream.resource<*>{%splat_a_size}, !stream.resource<*>{%splat_b_size}, !stream.resource<*>{%splat_a_size}
} => !stream.timepoint
// CHECK: util.return %[[RESULTS]]#0, %[[RESULTS]]#1, %[[RESULTS]]#0, %[[TIMEPOINT]]
util.return %results#0, %results#1, %results#2, %timepoint : !stream.resource<*>, !stream.resource<*>, !stream.resource<*>, !stream.timepoint
}

// -----

// CHECK-LABEL: @FoldAsyncExecuteTiedDuplicateResults
// CHECK-SAME: (%[[TARGET:.+]]: !stream.resource<*>, %[[TARGET_SIZE:.+]]: index, %[[SPLAT_SIZE:.+]]: index, %[[SPLAT_VALUE:.+]]: i32)
util.func private @FoldAsyncExecuteTiedDuplicateResults(%target: !stream.resource<*>, %target_size: index, %splat_size: index, %splat_value: i32) -> (!stream.resource<*>, !stream.resource<*>, !stream.resource<*>, !stream.timepoint) {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
// CHECK: %[[RESULTS:.+]]:2, %[[TIMEPOINT:.+]] = stream.async.execute with({{.+}}) -> (%[[TARGET]]{%[[TARGET_SIZE]]}, !stream.resource<*>{%[[SPLAT_SIZE]]}) {
%results:3, %timepoint = stream.async.execute with(%target as %target_capture: !stream.resource<*>{%target_size}) -> (%target as !stream.resource<*>{%target_size}, !stream.resource<*>{%splat_size}, %target as !stream.resource<*>{%target_size}) {
// CHECK: %[[TARGET_FILL:.+]] = stream.async.fill
%target_fill = stream.async.fill %splat_value, %target_capture[%c0 to %c128 for %c128] : i32 -> %target_capture as !stream.resource<*>{%target_size}
// CHECK: %[[SPLAT:.+]] = stream.async.splat %[[SPLAT_VALUE]]
%splat = stream.async.splat %splat_value : i32 -> !stream.resource<*>{%splat_size}
// CHECK: stream.yield %[[TARGET_FILL]], %[[SPLAT]] : !stream.resource<*>{%[[TARGET_SIZE]]}, !stream.resource<*>{%[[SPLAT_SIZE]]}
stream.yield %target_fill, %splat, %target_fill : !stream.resource<*>{%target_size}, !stream.resource<*>{%splat_size}, !stream.resource<*>{%target_size}
} => !stream.timepoint
// CHECK: util.return %[[RESULTS]]#0, %[[RESULTS]]#1, %[[RESULTS]]#0, %[[TIMEPOINT]]
util.return %results#0, %results#1, %results#2, %timepoint : !stream.resource<*>, !stream.resource<*>, !stream.resource<*>, !stream.timepoint
}

// -----

// CHECK-LABEL: @TieRegionResultsAsyncConcurrentOp
Expand Down
109 changes: 105 additions & 4 deletions compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,88 @@ void eraseRegionResults(Region &region,
}
}

// Finds any yielded region results that are duplicates of each other (as
// defined by having the same SSA value). Returns a map of result indices to the
// result index that is the leader of an equivalence class on each SSA value.
// Example:
// yield %a, %b, %a, %b -> (0, 1, 0, 1)
// Note that a result index in the map with a value of its own index indicates
// a result that is not duplicated and that must be preserved.
static SmallVector<unsigned> findDuplicateRegionResults(Region &region) {
// Gather all yield ops in the closure.
SmallVector<IREE::Util::ClosureYieldOpInterface> yieldOps;
for (auto &block : region.getBlocks()) {
if (block.empty()) {
continue;
}
auto *terminatorOp = block.getTerminator();
if (auto yieldOp = dyn_cast_if_present<IREE::Util::ClosureYieldOpInterface>(
terminatorOp)) {
yieldOps.push_back(yieldOp);
}
}
if (yieldOps.empty()) {
return {};
}
const unsigned resultCount =
yieldOps.front().getClosureResultsMutable().size();

// Build a map of result indices to its base duplicate for each yield site.
// Base/non-duplicated values will be identity.
// Example:
// yield %a, %b, %a, %b -> (0, 1, 0, 1)
static const int kUnassigned = -1;
SmallVector<SmallVector<unsigned>> dupeIndexMaps(yieldOps.size());
for (auto yieldOp : llvm::enumerate(yieldOps)) {
auto &dupeIndexMap = dupeIndexMaps[yieldOp.index()];
dupeIndexMap.resize(resultCount, kUnassigned);
auto operands = yieldOp.value().getClosureResultsMutable();
for (unsigned i = 0; i < operands.size(); ++i) {
for (unsigned j = 0; j < i; ++j) {
if (operands[j].get() == operands[i].get()) {
dupeIndexMap[i] = j;
break;
}
}
}
}

// Per-result now find which are consistently duplicated.
// Note that we may have multiple yield ops and we have to ensure that one
// returning duplicates does not influence others that may not be.
llvm::BitVector sameValues(resultCount);
llvm::BitVector deadResultsMap(resultCount);
auto uniformDupeIndexMap =
llvm::to_vector(llvm::seq(0u, resultCount)); // old -> new
for (unsigned idx = 0; idx < resultCount; ++idx) {
if (deadResultsMap.test(idx))
continue;
// Each bit represents a result that duplicates the result at idx.
// We walk all the sites and AND their masks together to get the safe
// set of duplicate results.
// Example for %0: yield %a, %b, %a -> b001
// Example for %1: yield %a, %b, %a -> b000
sameValues.set(); // note reused
for (auto &dupeIndexMap : dupeIndexMaps) {
for (unsigned i = 0; i < resultCount; ++i) {
if (i == idx || dupeIndexMap[i] != idx) {
sameValues.reset(i);
}
}
}
if (sameValues.none()) {
uniformDupeIndexMap[idx] = idx;
continue;
}
deadResultsMap |= sameValues;
uniformDupeIndexMap[idx] = idx;
for (auto dupeIdx : sameValues.set_bits()) {
uniformDupeIndexMap[dupeIdx] = idx;
}
}
return uniformDupeIndexMap;
}

// Returns true if |constantOp| represents a (logically) small constant value
// that can be inlined into a closure.
//
Expand Down Expand Up @@ -227,30 +309,43 @@ LogicalResult optimizeClosureLikeOp(const ClosureOptimizationOptions &options,
for (auto opArg : llvm::enumerate(closureOp.getClosureOperands())) {
auto blockArg = entryBlock.getArgument(opArg.index());
if (blockArg.use_empty()) {
// Not used - Drop.
// Not used - drop.
elidedOperands.push_back(opArg.index());
blockArgReplacements[opArg.index()] = BlockArgument();
continue;
}
auto existingIt = argToBlockMap.find(opArg.value());
if (existingIt == argToBlockMap.end()) {
// Not found - Record for deduping.
// Not found - record for deduping.
argToBlockMap.insert(std::make_pair(opArg.value(), blockArg));
} else {
// Found - Replace.
// Found - replace.
elidedOperands.push_back(opArg.index());
blockArgReplacements[opArg.index()] = existingIt->second;
}
}

// Check for unused results.
// Find duplicate results (where all yield sites return a duplicate value) as
// a map from result index to the result index it is a duplicate of. Results
// that are not duplicates (or are the base value) have an identity entry.
auto duplicateResultMap =
findDuplicateRegionResults(closureOp.getClosureBodyRegion());

// Check for unused or duplicate results.
SmallVector<Value> preservedResults;
SmallVector<unsigned> elidedResults;
SmallVector<std::pair<Value, Value>> resultReplacements;
for (auto result : llvm::enumerate(closureOp.getClosureResults())) {
// You can drop a result if the use is empty and not read via a tie.
auto access = closureOp.getResultAccess(result.index());
if (result.value().use_empty() && !access.isRead) {
elidedResults.push_back(result.index());
} else if (!duplicateResultMap.empty() &&
duplicateResultMap[result.index()] != result.index()) {
elidedResults.push_back(result.index());
resultReplacements.push_back(std::make_pair(
result.value(),
closureOp.getClosureResults()[duplicateResultMap[result.index()]]));
} else {
preservedResults.push_back(result.value());
}
Expand Down Expand Up @@ -281,6 +376,12 @@ LogicalResult optimizeClosureLikeOp(const ClosureOptimizationOptions &options,
}
}

// Replace duplicate results prior to cloning - the SSA values will no longer
// exist afterward.
for (auto [oldResult, newResult] : resultReplacements) {
rewriter.replaceAllUsesWith(oldResult, newResult);
}

// Clone the op with the elidable operands and results removed.
auto newOp = closureOp.cloneReplacementExcludingOperandsAndResults(
elidedOperands, elidedResults, rewriter);
Expand Down
25 changes: 25 additions & 0 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,31 @@ def Util_ClosureOpInterface : OpInterface<"ClosureOpInterface"> {
];
}

def Util_ClosureYieldOpInterface : OpInterface<"ClosureYieldOpInterface"> {
let cppNamespace = "::mlir::iree_compiler::IREE::Util";

let description = [{
Interface for ops that yield results from a closure. These also generally
implement RegionBranchTerminatorOpInterface and that should be used when
possible for better interoperability with upstream code. Ops that have
special operand handling such as shape- and size-aware ops need to use this
interface to provide the operand range representing the yielded values.
}];

let methods = [
InterfaceMethod<
/*desc=*/[{Returns yielded closure result values.}],
/*retTy=*/"MutableOperandRange",
/*methodName=*/"getClosureResultsMutable",
/*args=*/(ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
return MutableOperandRange($_op, 0, $_op->getNumOperands());
}]
>,
];
}

//===----------------------------------------------------------------------===//
// IREE::Util::InferIntDivisibilityOpInterface
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit 03be423

Please sign in to comment.