From 3b7d9dd587a36919489b106c8d1ad41cca11a356 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com> Date: Fri, 21 Feb 2025 15:41:24 -0800 Subject: [PATCH] [GPU] Add a flag to force workgroup distribution using forall. (#20063) There is on-going work that is trying to make the workgroup distribution use scf.forall op by default. While that lands, add a flag to make this the default. Note that this is a developer only flag and not meant for general usage. Signed-off-by: MaheshRavishankar --- .../src/iree/compiler/Codegen/LLVMGPU/Passes.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 19d4d3bec7d5..aaf45bd28849 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -74,6 +74,11 @@ static llvm::cl::opt clLLVMGPUEnableSharedMemoryReuse( "Enable shared memory reuse in the vector distribute pipeline"), llvm::cl::init(false)); +static llvm::cl::opt clDistributeToWorkgroupsUsingForall( + "iree-llvmgpu-test-distribute-to-workgroups-using-forall", + llvm::cl::desc("Use scf.forall for distribution to workgroups"), + llvm::cl::init(false), llvm::cl::Hidden); + //===----------------------------------------------------------------------===// // Bufferization Configuration //===----------------------------------------------------------------------===// @@ -902,7 +907,8 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager, } void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) { - tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); + tileAndDistributeToWorkgroup( + funcPassManager, /*useForall=*/clDistributeToWorkgroupsUsingForall); funcPassManager.addPass(createRematerializeParallelOpsPass()); funcPassManager.addPass(createConfigTrackingCanonicalizerPass()); funcPassManager.addPass(createGPUTileReductionPass()); @@ -973,8 +979,8 @@ void addGPUSimpleDistributePassPipeline(OpPassManager &funcPassManager) { tileAndBufferize(funcPassManager); // Distribute linalg onto threads within the workgroup. - funcPassManager.addPass( - createLLVMGPUTileAndDistributePass(/*distributeToWarp=*/false)); + funcPassManager.addPass(createLLVMGPUTileAndDistributePass( + /*distributeToWarp=*/clDistributeToWorkgroupsUsingForall)); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCSEPass());