Skip to content

Commit

Permalink
[GPU] Add a flag to force workgroup distribution using forall. (#20063)
Browse files Browse the repository at this point in the history
There is on-going work that is trying to make the workgroup distribution
use scf.forall op by default. While that lands, add a flag to make this
the default. Note that this is a developer only flag and not meant for
general usage.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
  • Loading branch information
MaheshRavishankar authored Feb 21, 2025
1 parent d6da252 commit 3b7d9dd
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ static llvm::cl::opt<bool> clLLVMGPUEnableSharedMemoryReuse(
"Enable shared memory reuse in the vector distribute pipeline"),
llvm::cl::init(false));

static llvm::cl::opt<bool> clDistributeToWorkgroupsUsingForall(
"iree-llvmgpu-test-distribute-to-workgroups-using-forall",
llvm::cl::desc("Use scf.forall for distribution to workgroups"),
llvm::cl::init(false), llvm::cl::Hidden);

//===----------------------------------------------------------------------===//
// Bufferization Configuration
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -902,7 +907,8 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
}

void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) {
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
tileAndDistributeToWorkgroup(
funcPassManager, /*useForall=*/clDistributeToWorkgroupsUsingForall);
funcPassManager.addPass(createRematerializeParallelOpsPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createGPUTileReductionPass());
Expand Down Expand Up @@ -973,8 +979,8 @@ void addGPUSimpleDistributePassPipeline(OpPassManager &funcPassManager) {
tileAndBufferize(funcPassManager);

// Distribute linalg onto threads within the workgroup.
funcPassManager.addPass(
createLLVMGPUTileAndDistributePass(/*distributeToWarp=*/false));
funcPassManager.addPass(createLLVMGPUTileAndDistributePass(
/*distributeToWarp=*/clDistributeToWorkgroupsUsingForall));
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

Expand Down

0 comments on commit 3b7d9dd

Please sign in to comment.