From 612e14e5b4ac0aecf3650b9d8f7d9b4826dc7525 Mon Sep 17 00:00:00 2001 From: Vinicius Couto Espindola Date: Fri, 7 Jul 2023 10:43:31 -0300 Subject: [PATCH] [CIR][Lowering][Bugfix] Refactor for loop lowering This refactor merges the lowering logic of all the different kinds of loops into a single function. It also removes unnecessary LIT tests that validate LLVM dialect to LLVM IR lowering, as this functionality is not within CIR's scope. Fixes #153 ghstack-source-id: ebaab859057a6d81f1978fd88701c28402712562 Pull Request resolved: https://github.com/llvm/clangir/pull/156 --- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 110 ++++-------------- clang/test/CIR/Lowering/dot.cir | 100 +++------------- clang/test/CIR/Lowering/loop.cir | 49 ++------ 3 files changed, 50 insertions(+), 209 deletions(-) diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index b42d91ce941d6..dc2d5f037c29f 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -135,10 +135,9 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern { return mlir::success(); } - mlir::LogicalResult - rewriteWhileLoop(mlir::cir::LoopOp loopOp, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter, - mlir::cir::LoopOpKind kind) const { + mlir::LogicalResult rewriteLoop(mlir::cir::LoopOp loopOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter, + mlir::cir::LoopOpKind kind) const { auto *currentBlock = rewriter.getInsertionBlock(); auto *continueBlock = rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); @@ -150,16 +149,24 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern { if (fetchCondRegionYields(condRegion, yieldToBody, yieldToCont).failed()) return loopOp.emitError("failed to fetch yields in cond region"); - // Fetch required info from the condition region. + // Fetch required info from the body region. auto &bodyRegion = loopOp.getBody(); auto &bodyFrontBlock = bodyRegion.front(); auto bodyYield = dyn_cast(bodyRegion.back().getTerminator()); assert(bodyYield && "unstructured while loops are NYI"); + // Fetch required info from the step region. + auto &stepRegion = loopOp.getStep(); + auto &stepFrontBlock = stepRegion.front(); + auto stepYield = + dyn_cast(stepRegion.back().getTerminator()); + // Move loop op region contents to current CFG. rewriter.inlineRegionBefore(condRegion, continueBlock); rewriter.inlineRegionBefore(bodyRegion, continueBlock); + if (kind == LoopKind::For) // Ignore step if not a for-loop. + rewriter.inlineRegionBefore(stepRegion, continueBlock); // Set loop entry point to condition or to body in do-while cases. rewriter.setInsertionPointToEnd(currentBlock); @@ -174,9 +181,16 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern { rewriter.setInsertionPoint(yieldToBody); rewriter.replaceOpWithNewOp(yieldToBody, &bodyFrontBlock); - // Branch from body to condition. + // Branch from body to condition or to step on for-loop cases. rewriter.setInsertionPoint(bodyYield); - rewriter.replaceOpWithNewOp(bodyYield, &condFrontBlock); + auto &bodyExit = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock); + rewriter.replaceOpWithNewOp(bodyYield, &bodyExit); + + // Is a for loop: branch from step to condition. + if (kind == LoopKind::For) { + rewriter.setInsertionPoint(stepYield); + rewriter.replaceOpWithNewOp(stepYield, &condFrontBlock); + } // Remove the loop op. rewriter.eraseOp(loopOp); @@ -188,91 +202,11 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern { mlir::ConversionPatternRewriter &rewriter) const override { switch (loopOp.getKind()) { case LoopKind::For: - break; case LoopKind::While: case LoopKind::DoWhile: - return rewriteWhileLoop(loopOp, adaptor, rewriter, loopOp.getKind()); + return rewriteLoop(loopOp, adaptor, rewriter, loopOp.getKind()); } - auto loc = loopOp.getLoc(); - - auto *currentBlock = rewriter.getInsertionBlock(); - auto *remainingOpsBlock = - rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); - mlir::Block *continueBlock; - if (loopOp->getResults().size() == 0) - continueBlock = remainingOpsBlock; - else - llvm_unreachable("NYI"); - - auto &condRegion = loopOp.getCond(); - auto &condFrontBlock = condRegion.front(); - - auto &stepRegion = loopOp.getStep(); - auto &stepFrontBlock = stepRegion.front(); - auto &stepBackBlock = stepRegion.back(); - - auto &bodyRegion = loopOp.getBody(); - auto &bodyFrontBlock = bodyRegion.front(); - auto &bodyBackBlock = bodyRegion.back(); - - bool rewroteContinue = false; - bool rewroteBreak = false; - - for (auto &bb : condRegion) { - if (rewroteContinue && rewroteBreak) - break; - - if (auto yieldOp = dyn_cast(bb.getTerminator())) { - rewriter.setInsertionPointToEnd(yieldOp->getBlock()); - if (yieldOp.getKind().has_value()) { - switch (yieldOp.getKind().value()) { - case mlir::cir::YieldOpKind::Break: - case mlir::cir::YieldOpKind::Fallthrough: - case mlir::cir::YieldOpKind::NoSuspend: - llvm_unreachable("None of these should be present"); - case mlir::cir::YieldOpKind::Continue:; - rewriter.replaceOpWithNewOp( - yieldOp, yieldOp.getArgs(), &stepFrontBlock); - rewroteContinue = true; - } - } else { - rewriter.replaceOpWithNewOp( - yieldOp, yieldOp.getArgs(), continueBlock); - rewroteBreak = true; - } - } - } - - rewriter.inlineRegionBefore(condRegion, continueBlock); - - rewriter.inlineRegionBefore(stepRegion, continueBlock); - - if (auto stepYieldOp = - dyn_cast(stepBackBlock.getTerminator())) { - rewriter.setInsertionPointToEnd(stepYieldOp->getBlock()); - rewriter.replaceOpWithNewOp( - stepYieldOp, stepYieldOp.getArgs(), &bodyFrontBlock); - } else { - llvm_unreachable("What are we terminating with?"); - } - - rewriter.inlineRegionBefore(bodyRegion, continueBlock); - - if (auto bodyYieldOp = - dyn_cast(bodyBackBlock.getTerminator())) { - rewriter.setInsertionPointToEnd(bodyYieldOp->getBlock()); - rewriter.replaceOpWithNewOp( - bodyYieldOp, bodyYieldOp.getArgs(), &condFrontBlock); - } else { - llvm_unreachable("What are we terminating with?"); - } - - rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create(loc, mlir::ValueRange(), &condFrontBlock); - - rewriter.replaceOp(loopOp, continueBlock->getArguments()); - return mlir::success(); } }; diff --git a/clang/test/CIR/Lowering/dot.cir b/clang/test/CIR/Lowering/dot.cir index 22407d61e73e1..238dcdc9abde7 100644 --- a/clang/test/CIR/Lowering/dot.cir +++ b/clang/test/CIR/Lowering/dot.cir @@ -1,5 +1,5 @@ -// RUN: cir-tool %s -cir-to-llvm -o - | FileCheck %s -check-prefix=MLIR -// RUN: cir-tool %s -cir-to-llvm -o - | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM +// RUN: cir-tool %s -cir-to-llvm -o %t.mlir +// RUN: FileCheck --input-file=%t.mlir %s -check-prefix=MLIR !s32i = !cir.int module { @@ -95,24 +95,24 @@ module { // MLIR-NEXT: ^bb4: // pred: ^bb2 // MLIR-NEXT: llvm.br ^bb7 // MLIR-NEXT: ^bb5: // pred: ^bb3 -// MLIR-NEXT: %22 = llvm.load %12 : !llvm.ptr -// MLIR-NEXT: %23 = llvm.mlir.constant(1 : i32) : i32 -// MLIR-NEXT: %24 = llvm.add %22, %23 : i32 -// MLIR-NEXT: llvm.store %24, %12 : i32, !llvm.ptr +// MLIR-NEXT: %22 = llvm.load %1 : !llvm.ptr +// MLIR-NEXT: %23 = llvm.load %12 : !llvm.ptr +// MLIR-NEXT: %24 = llvm.getelementptr %22[%23] : (!llvm.ptr, i32) -> !llvm.ptr +// MLIR-NEXT: %25 = llvm.load %24 : !llvm.ptr +// MLIR-NEXT: %26 = llvm.load %3 : !llvm.ptr +// MLIR-NEXT: %27 = llvm.load %12 : !llvm.ptr +// MLIR-NEXT: %28 = llvm.getelementptr %26[%27] : (!llvm.ptr, i32) -> !llvm.ptr +// MLIR-NEXT: %29 = llvm.load %28 : !llvm.ptr +// MLIR-NEXT: %30 = llvm.fmul %25, %29 : f64 +// MLIR-NEXT: %31 = llvm.load %9 : !llvm.ptr +// MLIR-NEXT: %32 = llvm.fadd %31, %30 : f64 +// MLIR-NEXT: llvm.store %32, %9 : f64, !llvm.ptr // MLIR-NEXT: llvm.br ^bb6 // MLIR-NEXT: ^bb6: // pred: ^bb5 -// MLIR-NEXT: %25 = llvm.load %1 : !llvm.ptr -// MLIR-NEXT: %26 = llvm.load %12 : !llvm.ptr -// MLIR-NEXT: %27 = llvm.getelementptr %25[%26] : (!llvm.ptr, i32) -> !llvm.ptr -// MLIR-NEXT: %28 = llvm.load %27 : !llvm.ptr -// MLIR-NEXT: %29 = llvm.load %3 : !llvm.ptr -// MLIR-NEXT: %30 = llvm.load %12 : !llvm.ptr -// MLIR-NEXT: %31 = llvm.getelementptr %29[%30] : (!llvm.ptr, i32) -> !llvm.ptr -// MLIR-NEXT: %32 = llvm.load %31 : !llvm.ptr -// MLIR-NEXT: %33 = llvm.fmul %28, %32 : f64 -// MLIR-NEXT: %34 = llvm.load %9 : !llvm.ptr -// MLIR-NEXT: %35 = llvm.fadd %34, %33 : f64 -// MLIR-NEXT: llvm.store %35, %9 : f64, !llvm.ptr +// MLIR-NEXT: %33 = llvm.load %12 : !llvm.ptr +// MLIR-NEXT: %34 = llvm.mlir.constant(1 : i32) : i32 +// MLIR-NEXT: %35 = llvm.add %33, %34 : i32 +// MLIR-NEXT: llvm.store %35, %12 : i32, !llvm.ptr // MLIR-NEXT: llvm.br ^bb2 // MLIR-NEXT: ^bb7: // pred: ^bb4 // MLIR-NEXT: llvm.br ^bb8 @@ -123,67 +123,3 @@ module { // MLIR-NEXT: llvm.return %37 : f64 // MLIR-NEXT: } // MLIR-NEXT: } - -// LLVM: define double @dot(ptr %0, ptr %1, i32 %2) { -// LLVM-NEXT: %4 = alloca ptr, i64 1, align 8 -// LLVM-NEXT: %5 = alloca ptr, i64 1, align 8 -// LLVM-NEXT: %6 = alloca i32, i64 1, align 4 -// LLVM-NEXT: %7 = alloca double, i64 1, align 8 -// LLVM-NEXT: %8 = alloca double, i64 1, align 8 -// LLVM-NEXT: store ptr %0, ptr %4, align 8 -// LLVM-NEXT: store ptr %1, ptr %5, align 8 -// LLVM-NEXT: store i32 %2, ptr %6, align 4 -// LLVM-NEXT: store double 0.000000e+00, ptr %8, align 8 -// LLVM-NEXT: br label %9 -// LLVM-EMPTY: -// LLVM-NEXT: 9: ; preds = %3 -// LLVM-NEXT: %10 = alloca i32, i64 1, align 4 -// LLVM-NEXT: store i32 0, ptr %10, align 4 -// LLVM-NEXT: br label %11 -// LLVM-EMPTY: -// LLVM-NEXT: 11: ; preds = %24, %9 -// LLVM-NEXT: %12 = load i32, ptr %10, align 4 -// LLVM-NEXT: %13 = load i32, ptr %6, align 4 -// LLVM-NEXT: %14 = icmp slt i32 %12, %13 -// LLVM-NEXT: %15 = zext i1 %14 to i32 -// LLVM-NEXT: %16 = icmp ne i32 %15, 0 -// LLVM-NEXT: %17 = zext i1 %16 to i8 -// LLVM-NEXT: %18 = trunc i8 %17 to i1 -// LLVM-NEXT: br i1 %18, label %19, label %20 -// LLVM-EMPTY: -// LLVM-NEXT: 19: ; preds = %11 -// LLVM-NEXT: br label %21 -// LLVM-EMPTY: -// LLVM-NEXT: 20: ; preds = %11 -// LLVM-NEXT: br label %36 -// LLVM-EMPTY: -// LLVM-NEXT: 21: ; preds = %19 -// LLVM-NEXT: %22 = load i32, ptr %10, align 4 -// LLVM-NEXT: %23 = add i32 %22, 1 -// LLVM-NEXT: store i32 %23, ptr %10, align 4 -// LLVM-NEXT: br label %24 -// LLVM-EMPTY: -// LLVM-NEXT: 24: ; preds = %21 -// LLVM-NEXT: %25 = load ptr, ptr %4, align 8 -// LLVM-NEXT: %26 = load i32, ptr %10, align 4 -// LLVM-NEXT: %27 = getelementptr double, ptr %25, i32 %26 -// LLVM-NEXT: %28 = load double, ptr %27, align 8 -// LLVM-NEXT: %29 = load ptr, ptr %5, align 8 -// LLVM-NEXT: %30 = load i32, ptr %10, align 4 -// LLVM-NEXT: %31 = getelementptr double, ptr %29, i32 %30 -// LLVM-NEXT: %32 = load double, ptr %31, align 8 -// LLVM-NEXT: %33 = fmul double %28, %32 -// LLVM-NEXT: %34 = load double, ptr %8, align 8 -// LLVM-NEXT: %35 = fadd double %34, %33 -// LLVM-NEXT: store double %35, ptr %8, align 8 -// LLVM-NEXT: br label %11 -// LLVM-EMPTY: -// LLVM-NEXT: 36: ; preds = %20 -// LLVM-NEXT: br label %37 -// LLVM-EMPTY: -// LLVM-NEXT: 37: ; preds = %36 -// LLVM-NEXT: %38 = load double, ptr %8, align 8 -// LLVM-NEXT: store double %38, ptr %7, align 8 -// LLVM-NEXT: %39 = load double, ptr %7, align 8 -// LLVM-NEXT: ret double %39 -// LLVM-NEXT: } diff --git a/clang/test/CIR/Lowering/loop.cir b/clang/test/CIR/Lowering/loop.cir index ffadc539b323b..e0a0d9840243e 100644 --- a/clang/test/CIR/Lowering/loop.cir +++ b/clang/test/CIR/Lowering/loop.cir @@ -1,9 +1,9 @@ -// RUN: cir-tool %s -cir-to-llvm -o - | FileCheck %s -check-prefix=MLIR -// RUN: cir-tool %s -cir-to-llvm -o - | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM +// RUN: cir-tool %s -cir-to-llvm -o %t.mlir +// RUN: FileCheck --input-file=%t.mlir %s -check-prefix=MLIR !s32i = !cir.int module { - cir.func @foo() { + cir.func @testFor() { %0 = cir.alloca !s32i, cir.ptr , ["i", init] {alignment = 4 : i64} %1 = cir.const(#cir.int<0> : !s32i) : !s32i cir.store %1, %0 : !s32i, cir.ptr @@ -29,12 +29,13 @@ module { } // MLIR: module { -// MLIR-NEXT: llvm.func @foo() { +// MLIR-NEXT: llvm.func @testFor() { // MLIR-NEXT: %0 = llvm.mlir.constant(1 : index) : i64 // MLIR-NEXT: %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i64) -> !llvm.ptr // MLIR-NEXT: %2 = llvm.mlir.constant(0 : i32) : i32 // MLIR-NEXT: llvm.store %2, %1 : i32, !llvm.ptr // MLIR-NEXT: llvm.br ^bb1 +// ============= Condition block ============= // MLIR-NEXT: ^bb1: // 2 preds: ^bb0, ^bb5 // MLIR-NEXT: %3 = llvm.load %1 : !llvm.ptr // MLIR-NEXT: %4 = llvm.mlir.constant(10 : i32) : i32 @@ -49,51 +50,21 @@ module { // MLIR-NEXT: llvm.br ^bb4 // MLIR-NEXT: ^bb3: // pred: ^bb1 // MLIR-NEXT: llvm.br ^bb6 +// ============= Body block ============= // MLIR-NEXT: ^bb4: // pred: ^bb2 +// MLIR-NEXT: llvm.br ^bb5 +// ============= Step block ============= +// MLIR-NEXT: ^bb5: // pred: ^bb4 // MLIR-NEXT: %11 = llvm.load %1 : !llvm.ptr // MLIR-NEXT: %12 = llvm.mlir.constant(1 : i32) : i32 // MLIR-NEXT: %13 = llvm.add %11, %12 : i32 // MLIR-NEXT: llvm.store %13, %1 : i32, !llvm.ptr -// MLIR-NEXT: llvm.br ^bb5 -// MLIR-NEXT: ^bb5: // pred: ^bb4 // MLIR-NEXT: llvm.br ^bb1 +// ============= Exit block ============= // MLIR-NEXT: ^bb6: // pred: ^bb3 // MLIR-NEXT: llvm.return // MLIR-NEXT: } -// LLVM: define void @foo() { -// LLVM-NEXT: %1 = alloca i32, i64 1, align 4 -// LLVM-NEXT: store i32 0, ptr %1, align 4 -// LLVM-NEXT: br label %2 -// LLVM-EMPTY: -// LLVM-NEXT: 2: -// LLVM-NEXT: %3 = load i32, ptr %1, align 4 -// LLVM-NEXT: %4 = icmp slt i32 %3, 10 -// LLVM-NEXT: %5 = zext i1 %4 to i32 -// LLVM-NEXT: %6 = icmp ne i32 %5, 0 -// LLVM-NEXT: %7 = zext i1 %6 to i8 -// LLVM-NEXT: %8 = trunc i8 %7 to i1 -// LLVM-NEXT: br i1 %8, label %9, label %10 -// LLVM-EMPTY: -// LLVM-NEXT: 9: -// LLVM-NEXT: br label %11 -// LLVM-EMPTY: -// LLVM-NEXT: 10: -// LLVM-NEXT: br label %15 -// LLVM-EMPTY: -// LLVM-NEXT: 11: -// LLVM-NEXT: %12 = load i32, ptr %1, align 4 -// LLVM-NEXT: %13 = add i32 %12, 1 -// LLVM-NEXT: store i32 %13, ptr %1, align 4 -// LLVM-NEXT: br label %14 -// LLVM-EMPTY: -// LLVM-NEXT: 14: -// LLVM-NEXT: br label %2 -// LLVM-EMPTY: -// LLVM-NEXT: 15: -// LLVM-NEXT: ret void -// LLVM-NEXT: } - // Test while cir.loop operation lowering. cir.func @testWhile(%arg0: !s32i) { %0 = cir.alloca !s32i, cir.ptr , ["i", init] {alignment = 4 : i64}