diff --git a/examples/example_unrolling_service/loop_unroller/BUILD b/examples/example_unrolling_service/loop_unroller/BUILD
new file mode 100644
index 000000000..3bec18c35
--- /dev/null
+++ b/examples/example_unrolling_service/loop_unroller/BUILD
@@ -0,0 +1,23 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the LICENSE file
+# in the root directory of this source tree.
+#
+# This package exposes the LLVM optimization pipeline as a CompilerGym service.
+load("@rules_cc//cc:defs.bzl", "cc_binary")
+
+cc_binary(
+ name = "loop_unroller",
+ srcs = [
+ "loop_unroller.cc",
+ ],
+ copts = [
+ "-Wall",
+ "-fdiagnostics-color=always",
+ "-fno-rtti",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "@llvm//10.0.0",
+ ],
+)
diff --git a/examples/example_unrolling_service/loop_unroller/README.md b/examples/example_unrolling_service/loop_unroller/README.md
new file mode 100644
index 000000000..5684e6580
--- /dev/null
+++ b/examples/example_unrolling_service/loop_unroller/README.md
@@ -0,0 +1,6 @@
+LLVM's opt does not always enforce the unrolling options passed as cli arguments. Hence, we created our own exeutable with custom unrolling pass in examples/example_unrolling_service/loop_unroller that enforces the unrolling factors passed in its cli.
+
+To run the custom unroller:
+```
+bazel run //examples/example_unrolling_service/loop_unroller:loop_unroller -- .ll --funroll-count= -S -o .ll
+```
diff --git a/examples/example_unrolling_service/loop_unroller/loop_unroller.cc b/examples/example_unrolling_service/loop_unroller/loop_unroller.cc
new file mode 100644
index 000000000..29c87559a
--- /dev/null
+++ b/examples/example_unrolling_service/loop_unroller/loop_unroller.cc
@@ -0,0 +1,201 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+#include
+#include
+#include
+#include
+#include
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Bitcode/BitcodeWriterPass.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IRPrintingPasses.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/IRReader/IRReader.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/SystemUtils.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+
+using namespace llvm;
+
+namespace llvm {
+/// Input LLVM module file name.
+cl::opt InputFilename(cl::Positional, cl::desc("Specify input filename"),
+ cl::value_desc("filename"), cl::init("-"));
+/// Output LLVM module file name.
+cl::opt OutputFilename("o", cl::desc("Specify output filename"),
+ cl::value_desc("filename"), cl::init("-"));
+
+static cl::opt UnrollEnable("floop-unroll", cl::desc("Enable loop unrolling"),
+ cl::init(true));
+
+static cl::opt UnrollCount(
+ "funroll-count", cl::desc("Use this unroll count for all loops including those with "
+ "unroll_count pragma values, for testing purposes"));
+
+// Force binary on terminals
+static cl::opt Force("f", cl::desc("Enable binary output on terminals"));
+
+// Output assembly
+static cl::opt OutputAssembly("S", cl::desc("Write output as LLVM assembly"));
+
+// Preserve use list order
+static cl::opt PreserveBitcodeUseListOrder(
+ "preserve-bc-uselistorder", cl::desc("Preserve use-list order when writing LLVM bitcode."),
+ cl::init(true), cl::Hidden);
+
+static cl::opt PreserveAssemblyUseListOrder(
+ "preserve-ll-uselistorder", cl::desc("Preserve use-list order when writing LLVM assembly."),
+ cl::init(false), cl::Hidden);
+
+// The INITIALIZE_PASS_XXX macros put the initialiser in the llvm namespace.
+void initializeLoopCounterPass(PassRegistry& Registry);
+
+class LoopCounter : public llvm::FunctionPass {
+ public:
+ static char ID;
+ std::unordered_map counts;
+
+ LoopCounter() : FunctionPass(ID) {}
+
+ virtual void getAnalysisUsage(AnalysisUsage& AU) const override {
+ AU.addRequired();
+ }
+
+ bool runOnFunction(llvm::Function& F) override {
+ LoopInfo& LI = getAnalysis().getLoopInfo();
+ auto Loops = LI.getLoopsInPreorder();
+
+ // Should really account for module, too.
+ counts[F.getName().str()] = Loops.size();
+
+ return false;
+ }
+};
+
+// Initialise the pass. We have to declare the dependencies we use.
+char LoopCounter::ID = 0;
+INITIALIZE_PASS_BEGIN(LoopCounter, "count-loops", "Count loops", false, false)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(LoopCounter, "count-loops", "Count loops", false, false)
+
+// The INITIALIZE_PASS_XXX macros put the initialiser in the llvm namespace.
+void initializeLoopUnrollConfiguratorPass(PassRegistry& Registry);
+
+class LoopUnrollConfigurator : public llvm::FunctionPass {
+ public:
+ static char ID;
+
+ LoopUnrollConfigurator() : FunctionPass(ID) {}
+
+ virtual void getAnalysisUsage(AnalysisUsage& AU) const override {
+ AU.addRequired();
+ }
+
+ bool runOnFunction(llvm::Function& F) override {
+ LoopInfo& LI = getAnalysis().getLoopInfo();
+ auto Loops = LI.getLoopsInPreorder();
+
+ // Should really account for module, too.
+ for (auto ALoop : Loops) {
+ if (UnrollEnable)
+ addStringMetadataToLoop(ALoop, "llvm.loop.unroll.enable", UnrollEnable);
+ if (UnrollCount)
+ addStringMetadataToLoop(ALoop, "llvm.loop.unroll.count", UnrollCount);
+ }
+
+ return false;
+ }
+};
+
+// Initialise the pass. We have to declare the dependencies we use.
+char LoopUnrollConfigurator::ID = 1;
+INITIALIZE_PASS_BEGIN(LoopUnrollConfigurator, "unroll-loops-configurator",
+ "Configurates loop unrolling", false, false)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(LoopUnrollConfigurator, "unroll-loops-configurator",
+ "Configurates loop unrolling", false, false)
+
+/// Reads a module from a file.
+/// On error, messages are written to stderr and null is returned.
+///
+/// \param Context LLVM Context for the module.
+/// \param Name Input file name.
+static std::unique_ptr readModule(LLVMContext& Context, StringRef Name) {
+ SMDiagnostic Diag;
+ std::unique_ptr Module = parseIRFile(Name, Diag, Context);
+
+ if (!Module)
+ Diag.print("llvm-counter", errs());
+
+ return Module;
+}
+
+} // namespace llvm
+
+int main(int argc, char** argv) {
+ cl::ParseCommandLineOptions(argc, argv,
+ " LLVM-Counter\n\n"
+ " Count the loops in a bitcode file.\n");
+
+ LLVMContext Context;
+ SMDiagnostic Err;
+ SourceMgr SM;
+ std::error_code EC;
+
+ std::unique_ptr Module = readModule(Context, InputFilename);
+
+ if (!Module)
+ return 1;
+
+ // Prepare output
+ ToolOutputFile Out(OutputFilename, EC, sys::fs::OF_None);
+ if (EC) {
+ Err = SMDiagnostic(OutputFilename, SourceMgr::DK_Error,
+ "Could not open output file: " + EC.message());
+ Err.print(argv[0], errs());
+ return 1;
+ }
+
+ // Run the passes
+ initializeLoopCounterPass(*PassRegistry::getPassRegistry());
+ legacy::PassManager PM;
+ LoopCounter* Counter = new LoopCounter();
+ LoopUnrollConfigurator* UnrollConfigurator = new LoopUnrollConfigurator();
+ PM.add(Counter);
+ PM.add(UnrollConfigurator);
+ PM.add(createLoopUnrollPass());
+ // Passes to output the module
+ if (OutputAssembly) {
+ PM.add(createPrintModulePass(Out.os(), "", PreserveAssemblyUseListOrder));
+ } else if (Force || !CheckBitcodeOutputToConsole(Out.os())) {
+ PM.add(createBitcodeWriterPass(Out.os(), PreserveBitcodeUseListOrder));
+ }
+ PM.run(*Module);
+
+ // Log loop stats
+ for (auto& x : Counter->counts) {
+ llvm::dbgs() << x.first << ": " << x.second << " loops" << '\n';
+ }
+
+ Out.keep();
+
+ return 0;
+}
diff --git a/examples/example_unrolling_service/service_py/BUILD b/examples/example_unrolling_service/service_py/BUILD
index b8ff53f0a..ec30ddf0f 100644
--- a/examples/example_unrolling_service/service_py/BUILD
+++ b/examples/example_unrolling_service/service_py/BUILD
@@ -7,6 +7,9 @@ load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library")
py_binary(
name = "example-unrolling-service-py",
srcs = ["example_service.py"],
+ data = [
+ "//examples/example_unrolling_service/loop_unroller",
+ ],
main = "example_service.py",
visibility = ["//visibility:public"],
deps = [
diff --git a/examples/example_unrolling_service/service_py/example_service.py b/examples/example_unrolling_service/service_py/example_service.py
index 8a4173827..e115c5de5 100755
--- a/examples/example_unrolling_service/service_py/example_service.py
+++ b/examples/example_unrolling_service/service_py/example_service.py
@@ -98,7 +98,11 @@ class UnrollingCompilationSession(CompilationSession):
]
def __init__(
- self, working_directory: Path, action_space: ActionSpace, benchmark: Benchmark
+ self,
+ working_directory: Path,
+ action_space: ActionSpace,
+ benchmark: Benchmark,
+ use_custom_opt: bool = True,
):
super().__init__(working_directory, action_space, benchmark)
logging.info("Started a compilation session for %s", benchmark.uri)
@@ -110,6 +114,9 @@ def __init__(
self._llc = str(llvm.llc_path())
self._llvm_diff = str(llvm.llvm_diff_path())
self._opt = str(llvm.opt_path())
+ # LLVM's opt does not always enforce the unrolling options passed as cli arguments. Hence, we created our own exeutable with custom unrolling pass in examples/example_unrolling_service/loop_unroller that enforces the unrolling factors passed in its cli.
+ # if self._use_custom_opt is true, use our custom exeutable, otherwise use LLVM's opt
+ self._use_custom_opt = use_custom_opt
# Dump the benchmark source to disk.
self._src_path = str(self.working_dir / "benchmark.c")
@@ -147,28 +154,47 @@ def apply_action(self, action: Action) -> Tuple[bool, Optional[ActionSpace], boo
if choice_index < 0 or choice_index >= num_choices:
raise ValueError("Out-of-range")
- cmd = self._action_space.choice[0].named_discrete_space.value[choice_index]
+ args = self._action_space.choice[0].named_discrete_space.value[choice_index]
logging.info(
"Applying action %d, equivalent command-line arguments: '%s'",
choice_index,
- cmd,
+ args,
)
+ args = args.split()
# make a copy of the LLVM file to compare its contents after applying the action
shutil.copyfile(self._llvm_path, self._llvm_before_path)
# apply action
- run_command(
- [
- self._opt,
- *cmd.split(),
- self._llvm_path,
- "-S",
- "-o",
- self._llvm_path,
- ],
- timeout=30,
- )
+ if self._use_custom_opt:
+ # our custom unroller has an additional `f` at the beginning of each argument
+ for i, arg in enumerate(args):
+ # convert - to -f
+ arg = arg[0] + "f" + arg[1:]
+ args[i] = arg
+ run_command(
+ [
+ "../loop_unroller/loop_unroller",
+ self._llvm_path,
+ *args,
+ "-S",
+ "-o",
+ self._llvm_path,
+ ],
+ timeout=30,
+ )
+ else:
+ run_command(
+ [
+ self._opt,
+ *args,
+ self._llvm_path,
+ "-S",
+ "-o",
+ self._llvm_path,
+ ],
+ timeout=30,
+ )
# compare the IR files to check if the action had an effect
try: