From b8ab8252df1b444a9f98822f19cb3f0bca96164c Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 4 Jun 2024 19:56:05 +0200 Subject: [PATCH] Fix Brutus example on nightly. (#75) * add v17 to API.jl * linetable is stored in debuginfo field now * using internal IRShow method to get linenumber * I forgot the dialects! * small error * Update examples/brutus.jl Co-authored-by: Paul Berg * Update examples/brutus.jl Co-authored-by: Paul Berg * MLIR17: fix brutus example location info, update Pass.jl, update executionengine example --------- Co-authored-by: Paul Berg --- examples/brutus.jl | 12 ++++++++++-- src/API/API.jl | 4 ++-- src/Dialects/Dialects.jl | 6 +++--- src/IR/Pass.jl | 6 +++++- test/executionengine.jl | 13 +++++++++++-- 5 files changed, 31 insertions(+), 10 deletions(-) diff --git a/examples/brutus.jl b/examples/brutus.jl index f1203656..5ebc597d 100644 --- a/examples/brutus.jl +++ b/examples/brutus.jl @@ -143,7 +143,16 @@ function code_mlir(f, types) for sidx in bb.stmts stmt = ir.stmts[sidx] inst = stmt[:inst] - line = ir.linetable[stmt[:line]+1] + line = @static if VERSION <= v"1.11" + ir.linetable[stmt[:line]+1] + else + lineinfonode = Base.IRShow.buildLineInfoNode(ir.debuginfo, :var"n/a", sidx) + if !isempty(lineinfonode) + last(lineinfonode) + else + (; ((:file, :line) .=> Base.IRShow.debuginfo_firstline(ir.debuginfo))...) + end + end if Meta.isexpr(inst, :call) val_type = stmt[:type] @@ -186,7 +195,6 @@ function code_mlir(f, types) cond_br = cf.cond_br(cond, true_args, false_args; trueDest=other_dest, falseDest=dest, location) push!(current_block, cond_br) elseif inst isa ReturnNode - line = ir.linetable[stmt[:line]+1] location = Location(string(line.file), line.line, 0) push!(current_block, func.return_([get_value(inst.val)]; location)) elseif Meta.isexpr(inst, :code_coverage_effect) diff --git a/src/API/API.jl b/src/API/API.jl index 89172a8f..76cd7887 100644 --- a/src/API/API.jl +++ b/src/API/API.jl @@ -18,14 +18,14 @@ end # generate version-less API functions begin - local ops = mapreduce(∪, [v14, v15, v16]) do mod + local ops = mapreduce(∪, [v14, v15, v16, v17]) do mod filter(names(mod; all=true)) do name name ∉ [nameof(mod), :eval, :include] && !startswith(string(name), '#') end end for op in ops - container_mods = filter([v14, v15, v16]) do mod + container_mods = filter([v14, v15, v16, v17]) do mod op in names(mod; all=true) end container_mods = map(container_mods) do mod diff --git a/src/Dialects/Dialects.jl b/src/Dialects/Dialects.jl index 4f37438b..dca5ffcb 100644 --- a/src/Dialects/Dialects.jl +++ b/src/Dialects/Dialects.jl @@ -17,7 +17,7 @@ end begin # list dialect operations - local dialectops = mapreduce(mergewith!(∪), [v14, v15, v16]) do mod + local dialectops = mapreduce(mergewith!(∪), [v14, v15, v16, v17]) do mod dialects = filter(names(mod; all=true)) do dialect dialect ∉ [nameof(mod), :eval, :include] && !startswith(string(dialect), '#') end @@ -33,11 +33,11 @@ begin for (dialect, ops) in dialectops mod = @eval module $dialect using ...MLIR: MLIR_VERSION, MLIRException - using ..Dialects: v14, v15, v16 + using ..Dialects: v14, v15, v16, v17 end for op in ops - container_mods = filter([v14, v15, v16]) do mod + container_mods = filter([v14, v15, v16, v17]) do mod dialect in names(mod; all=true) && op in names(getproperty(mod, dialect); all=true) end diff --git a/src/IR/Pass.jl b/src/IR/Pass.jl index 3ec0527a..82010468 100644 --- a/src/IR/Pass.jl +++ b/src/IR/Pass.jl @@ -64,7 +64,11 @@ end Run the provided `passManager` on the given `module`. """ function run!(pm::PassManager, mod::Module) - status = LogicalResult(API.mlirPassManagerRun(pm, mod)) + status = if MLIR_VERSION[] >= v"17" + LogicalResult(API.mlirPassManagerRunOnOp(pm, Operation(mod))) + else + LogicalResult(API.mlirPassManagerRun(pm, mod)) + end if isfailure(status) throw("failed to run pass manager on module") end diff --git a/test/executionengine.jl b/test/executionengine.jl index f162896f..38c19c62 100644 --- a/test/executionengine.jl +++ b/test/executionengine.jl @@ -24,7 +24,11 @@ function lowerModuleToLLVM(ctx, mod) op = "builtin.func" end opm = MLIR.API.mlirPassManagerGetNestedUnder(pm, op) - if LLVM.version() >= v"15" + if LLVM.version() >= v"17" + MLIR.API.mlirPassManagerAddOwnedPass( + pm, MLIR.API.mlirCreateConversionConvertFuncToLLVMPass() + ) + elseif LLVM.version() >= v"15" MLIR.API.mlirPassManagerAddOwnedPass( pm, MLIR.API.mlirCreateConversionConvertFuncToLLVM() ) @@ -43,7 +47,12 @@ function lowerModuleToLLVM(ctx, mod) opm, MLIR.API.mlirCreateConversionConvertArithmeticToLLVM() ) end - status = MLIR.API.mlirPassManagerRun(pm, mod) + status = if LLVM.version() >= v"17" + op = MLIR.API.mlirModuleGetOperation(mod) + MLIR.API.mlirPassManagerRunOnOp(pm, op) + else + MLIR.API.mlirPassManagerRun(pm, mod) + end # undefined symbol: mlirLogicalResultIsFailure if status.value == 0 error("Unexpected failure running pass failure")