diff --git a/compiler_gym/bin/BUILD b/compiler_gym/bin/BUILD index 3ae0fb971c..1b8f904a68 100644 --- a/compiler_gym/bin/BUILD +++ b/compiler_gym/bin/BUILD @@ -22,7 +22,7 @@ py_binary( srcs = ["datasets.py"], visibility = ["//visibility:public"], deps = [ - "//compiler_gym/datasets:dataset", + ":service", "//compiler_gym/envs", "//compiler_gym/util", "//compiler_gym/util/flags:env_from_flags", @@ -83,6 +83,7 @@ py_binary( srcs = ["service.py"], visibility = ["//visibility:public"], deps = [ + "//compiler_gym/datasets", "//compiler_gym/envs", "//compiler_gym/spaces", "//compiler_gym/util", diff --git a/compiler_gym/bin/datasets.py b/compiler_gym/bin/datasets.py index c6a2495ae5..783e3b6452 100644 --- a/compiler_gym/bin/datasets.py +++ b/compiler_gym/bin/datasets.py @@ -33,23 +33,6 @@ +-------------------+--------------+-----------------+----------------+ These benchmarks are ready for use. Deactivate them using `--deactivate=`. - +---------------------+-----------+-----------------+----------------+ - | Inactive Datasets | License | #. Benchmarks | Size on disk | - +=====================+===========+=================+================+ - | Total | | 0 | 0 Bytes | - +---------------------+-----------+-----------------+----------------+ - These benchmarks may be activated using `--activate=`. - - +------------------------+---------------------------------+-----------------+----------------+ - | Downloadable Dataset | License | #. Benchmarks | Size on disk | - +========================+=================================+=================+================+ - | blas-v0 | BSD 3-Clause | 300 | 4.0 MB | - +------------------------+---------------------------------+-----------------+----------------+ - | polybench-v0 | BSD 3-Clause | 27 | 162.6 kB | - +------------------------+---------------------------------+-----------------+----------------+ - These benchmarks may be installed using `--download= --activate=`. - - Downloading datasets -------------------- @@ -131,23 +114,13 @@ A :code:`--delete_all` flag can be used to delete all of the locally installed datasets. """ -import os import sys -from pathlib import Path -from typing import Tuple -import humanize from absl import app, flags -from compiler_gym.datasets.dataset import ( - LegacyDataset, - activate, - deactivate, - delete, - require, -) +from compiler_gym.bin.service import summarize_datasets +from compiler_gym.datasets.dataset import activate, deactivate, delete, require from compiler_gym.util.flags.env_from_flags import env_from_flags -from compiler_gym.util.tabulate import tabulate flags.DEFINE_list( "download", @@ -175,31 +148,6 @@ FLAGS = flags.FLAGS -def get_count_and_size_of_directory_contents(root: Path) -> Tuple[int, int]: - """Return the number of files and combined size of a directory.""" - count, size = 0, 0 - for root, _, files in os.walk(str(root)): - count += len(files) - size += sum(os.path.getsize(f"{root}/{file}") for file in files) - return count, size - - -def enumerate_directory(name: str, path: Path): - rows = [] - for path in path.iterdir(): - if not path.is_file() or not path.name.endswith(".json"): - continue - dataset = LegacyDataset.from_json_file(path) - rows.append( - (dataset.name, dataset.license, dataset.file_count, dataset.size_bytes) - ) - rows.append(("Total", "", sum(r[2] for r in rows), sum(r[3] for r in rows))) - return tabulate( - [(n, l, humanize.intcomma(f), humanize.naturalsize(s)) for n, l, f, s in rows], - headers=(name, "License", "#. 
Benchmarks", "Size on disk"), - ) - - def main(argv): """Main entry point.""" if len(argv) != 1: @@ -207,28 +155,20 @@ def main(argv): env = env_from_flags() try: - if not env.datasets_site_path: - raise app.UsageError("Environment has no benchmarks site path") - - env.datasets_site_path.mkdir(parents=True, exist_ok=True) - env.inactive_datasets_site_path.mkdir(parents=True, exist_ok=True) - invalidated_manifest = False for name_or_url in FLAGS.download: require(env, name_or_url) if FLAGS.download_all: - for dataset in env.available_datasets: - require(env, dataset) + for dataset in env.datasets: + dataset.install() for name in FLAGS.activate: activate(env, name) invalidated_manifest = True if FLAGS.activate_all: - for path in env.inactive_datasets_site_path.iterdir(): - activate(env, path.name) invalidated_manifest = True for name in FLAGS.deactivate: @@ -236,8 +176,6 @@ def main(argv): invalidated_manifest = True if FLAGS.deactivate_all: - for path in env.datasets_site_path.iterdir(): - deactivate(env, path.name) invalidated_manifest = True for name in FLAGS.delete: @@ -246,41 +184,10 @@ def main(argv): if invalidated_manifest: env.make_manifest_file() - print(f"{env.spec.id} benchmarks site dir: {env.datasets_site_path}") + print(f"{env.spec.id} benchmarks site dir: {env.datasets.site_data_path}") print() print( - enumerate_directory("Active Datasets", env.datasets_site_path), - ) - print( - "These benchmarks are ready for use. Deactivate them using `--deactivate=`." - ) - print() - print(enumerate_directory("Inactive Datasets", env.inactive_datasets_site_path)) - print("These benchmarks may be activated using `--activate=`.") - print() - print( - tabulate( - sorted( - [ - ( - d.name, - d.license, - humanize.intcomma(d.file_count), - humanize.naturalsize(d.size_bytes), - ) - for d in env.available_datasets.values() - ] - ), - headers=( - "Downloadable Dataset", - "License", - "#. Benchmarks", - "Size on disk", - ), - ) - ) - print( - "These benchmarks may be installed using `--download= --activate=`." + summarize_datasets(env.datasets), ) finally: env.close() diff --git a/compiler_gym/bin/service.py b/compiler_gym/bin/service.py index 9bbde5bc3d..fc525c509a 100644 --- a/compiler_gym/bin/service.py +++ b/compiler_gym/bin/service.py @@ -66,12 +66,17 @@ $ python -m compiler_gym.bin.service --local_service_binary=/path/to/service/binary """ +from typing import Iterable + +import humanize from absl import app, flags +from compiler_gym.datasets import Dataset from compiler_gym.envs import CompilerEnv from compiler_gym.spaces import Commandline from compiler_gym.util.flags.env_from_flags import env_from_flags from compiler_gym.util.tabulate import tabulate +from compiler_gym.util.truncate import truncate flags.DEFINE_integer( "heading_level", @@ -93,18 +98,37 @@ def shape2str(shape, n: int = 80): return f"`{string}`" +def summarize_datasets(datasets: Iterable[Dataset]) -> str: + rows = [] + for dataset in datasets: + # Raw numeric values here, formatted below. + rows.append( + ( + dataset.name, + truncate(dataset.description, max_line_len=60), + dataset.n, + dataset.site_data_size_in_bytes, + ) + ) + rows.append(("Total", "", sum(r[2] for r in rows), sum(r[3] for r in rows))) + return tabulate( + [ + (n, l, humanize.intcomma(f) if f else "∞", humanize.naturalsize(s)) + for n, l, f, s in rows + ], + headers=("Dataset", "Description", "#. 
Benchmarks", "Size on disk"), + ) + + def print_service_capabilities(env: CompilerEnv, base_heading_level: int = 1): """Discover and print the capabilities of a CompilerGym service. :param env: An environment. """ print(header(f"CompilerGym Service `{env.service}`", base_heading_level).strip()) - print(header("Programs", base_heading_level + 1)) + print(header("Datasets", base_heading_level + 1)) print( - tabulate( - [(p,) for p in sorted(env.benchmarks)], - headers=("Benchmark",), - ) + summarize_datasets(env.datasets), ) print(header("Observation Spaces", base_heading_level + 1)) print( diff --git a/compiler_gym/datasets/BUILD b/compiler_gym/datasets/BUILD index 6bd60d6902..0849953be0 100644 --- a/compiler_gym/datasets/BUILD +++ b/compiler_gym/datasets/BUILD @@ -6,18 +6,17 @@ load("@rules_python//python:defs.bzl", "py_library") py_library( name = "datasets", - srcs = ["__init__.py"], - visibility = ["//visibility:public"], - deps = [ - ":dataset", + srcs = [ + "__init__.py", + "benchmark.py", + "dataset.py", + "datasets.py", + "tar_dataset.py", ], -) - -py_library( - name = "dataset", - srcs = ["dataset.py"], - visibility = ["//compiler_gym:__subpackages__"], + visibility = ["//visibility:public"], deps = [ + "//compiler_gym:validation_result", + "//compiler_gym/service/proto", "//compiler_gym/util", ], ) diff --git a/compiler_gym/datasets/__init__.py b/compiler_gym/datasets/__init__.py index b0dc9440c5..c9a0b69328 100644 --- a/compiler_gym/datasets/__init__.py +++ b/compiler_gym/datasets/__init__.py @@ -3,12 +3,27 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Manage datasets of benchmarks.""" +from compiler_gym.datasets.benchmark import Benchmark from compiler_gym.datasets.dataset import ( + Dataset, LegacyDataset, activate, deactivate, delete, require, ) +from compiler_gym.datasets.datasets import Datasets +from compiler_gym.datasets.tar_dataset import TarDataset, TarDatasetWithManifest -__all__ = ["LegacyDataset", "require", "activate", "deactivate", "delete"] +__all__ = [ + "activate", + "Benchmark", + "Dataset", + "Datasets", + "deactivate", + "delete", + "LegacyDataset", + "require", + "TarDataset", + "TarDatasetWithManifest", +] diff --git a/compiler_gym/datasets/benchmark.py b/compiler_gym/datasets/benchmark.py new file mode 100644 index 0000000000..09df702f73 --- /dev/null +++ b/compiler_gym/datasets/benchmark.py @@ -0,0 +1,159 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import re +from concurrent.futures import as_completed +from pathlib import Path +from typing import Callable, Iterable, List, Optional + +from compiler_gym.service.proto import Benchmark as BenchmarkProto +from compiler_gym.service.proto import File +from compiler_gym.util import thread_pool +from compiler_gym.validation_result import ValidationError + +# A validation callback is a function that takes a single CompilerEnv instance +# as its argument and returns an iterable sequence of zero or more +# ValidationError tuples. +ValidationCallback = Callable[["CompilerEnv"], Iterable[ValidationError]] # noqa: F821 + + +# Regular expression that matches the full two-part URI prefix of a dataset: +# :// +# +# E.g. "benchmark://foo-v0". 
+DATASET_NAME_RE = re.compile(
+    r"(?P<dataset>(?P<dataset_protocol>[a-zA-z0-9-_]+)://(?P<dataset_name>[a-zA-z0-9-_]+-v(?P<dataset_version>[0-9]+)))"
+)
+
+# Regular expression that matches the full three-part format of a benchmark URI:
+#     <protocol>://<dataset-name>/<benchmark-name>
+#
+# E.g. "benchmark://foo-v0/" or "benchmark://foo-v0/program".
+BENCHMARK_URI_RE = re.compile(
+    r"(?P<dataset>(?P<dataset_protocol>[a-zA-z0-9-_]+)://(?P<dataset_name>[a-zA-z0-9-_]+-v(?P<dataset_version>[0-9]+)))/(?P<benchmark_name>[^\s]*)$"
+)
+
+
+class Benchmark(object):
+    """A benchmark program for an environment.
+
+    At a minimum a benchmark is just a program that can be used by a
+    :class:`CompilerEnv <compiler_gym.envs.CompilerEnv>` as a program to
+    optimize. Benchmarks may provide additional functionality such as runtime
+    checks or methods for validating the semantics of a benchmark. The
+    benchmark for an environment can be set during :meth:`env.reset()
+    <compiler_gym.envs.CompilerEnv.reset>`. The currently active benchmark can
+    be queried using :attr:`env.benchmark
+    <compiler_gym.envs.CompilerEnv.benchmark>`:
+
+        >>> env = gym.make("llvm-v0")
+        >>> env.reset(benchmark="cBench-v1/crc32")
+        >>> env.benchmark
+        cBench-v1/crc32
+
+    A Benchmark instance wraps an instance of the :code:`Benchmark` protocol
+    buffer from the `RPC interface `_
+    with additional functionality.
+
+    Benchmarks are not normally instantiated directly. Existing benchmarks can
+    be queried using :meth:`dataset.benchmark()`. Compiler environments may
+    provide helper functions for generating benchmarks, such as
+    :meth:`env.make_benchmark() <compiler_gym.envs.LlvmEnv.make_benchmark>` for
+    LLVM.
+
+    The data underlying a Benchmark instance should be considered immutable.
+    New attributes cannot be assigned to Benchmark instances.
+    """
+
+    __slots__ = ["_proto", "_validation_callbacks"]
+
+    def __init__(
+        self,
+        proto: BenchmarkProto,
+        validation_callbacks: Optional[List[ValidationCallback]] = None,
+    ):
+        self._proto = proto
+        self._validation_callbacks = validation_callbacks or []
+
+    def __repr__(self) -> str:
+        return str(self.uri)
+
+    @property
+    def uri(self) -> str:
+        """The URI of the benchmark.
+
+        :return: A URI string.
+        :type: string
+        """
+        return self._proto.uri
+
+    @property
+    def proto(self) -> BenchmarkProto:
+        """The protocol buffer representing the benchmark.
+
+        :return: A Benchmark message.
+        :type: :code:`Benchmark`
+        """
+        return self._proto
+
+    def is_validatable(self) -> bool:
+        """Whether the benchmark has any validation callbacks.
+
+        :return: :code:`True` if the benchmark has at least one validation
+            callback.
+        """
+        return self._validation_callbacks != []
+
+    def validate(self, env: "CompilerEnv") -> Iterable[ValidationError]:  # noqa: F821
+        """Run any validation callbacks and yield the errors that they produce.
+
+        Validation callbacks must be thread safe and must not modify the
+        environment.
+
+        :param env: The :class:`CompilerEnv <compiler_gym.envs.CompilerEnv>`
+            instance that is being validated.
+        """
+        executor = thread_pool.get_thread_pool_executor()
+        futures = (
+            executor.submit(validator, env) for validator in self.validation_callbacks()
+        )
+        for future in as_completed(futures):
+            result: Iterable[ValidationError] = future.result()
+            yield from result
+
+    def validation_callbacks(
+        self,
+    ) -> List[ValidationCallback]:
+        """Return the list of registered validation callbacks, e.g. for
+        difftest, valgrind, etc."""
+        return self._validation_callbacks
+
+    def add_validation_callbacks(
+        self,
+        validation_callbacks: Iterable[ValidationCallback],
+    ):
+        """Register new validation callbacks."""
+        self._validation_callbacks += list(validation_callbacks)
+
+    @classmethod
+    def from_file(cls, uri: str, path: Path):
+        """Construct a benchmark from a file on disk.
+
+        :param uri: The URI of the benchmark.
+        :param path: The filesystem path of the benchmark.
+        """
+        return cls(
+            proto=BenchmarkProto(
+                uri=uri, program=File(uri=f"file:///{Path(path).absolute()}")
+            ),
+        )
+
+    @classmethod
+    def from_file_contents(cls, uri: str, data: bytes):
+        """Construct a benchmark from raw data.
+
+        :param uri: The URI of the benchmark.
+        :param data: The raw contents of the benchmark program.
+        """
+        return cls(proto=BenchmarkProto(uri=uri, program=File(contents=data)))
diff --git a/compiler_gym/datasets/dataset.py b/compiler_gym/datasets/dataset.py
index 72608d63ee..f4923d58af 100644
--- a/compiler_gym/datasets/dataset.py
+++ b/compiler_gym/datasets/dataset.py
@@ -2,21 +2,156 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import io
 import json
-import os
+import logging
+import re
 import shutil
-import tarfile
-import warnings
+import subprocess
 from pathlib import Path
-from typing import List, NamedTuple, Optional, Union
+from typing import Iterable, List, NamedTuple, Optional
 
-import fasteners
+import numpy as np
 from deprecated.sphinx import deprecated
 
-from compiler_gym.util.download import download
+from compiler_gym.datasets.benchmark import Benchmark
+from compiler_gym.util.debug_util import get_logging_level
+
+# Regular expression that matches the full two-part URI prefix of a dataset:
+#     <protocol>://<dataset-name>
+#
+# E.g. "benchmark://foo-v0".
+DATASET_NAME_RE = re.compile(
+    r"(?P<dataset>(?P<dataset_protocol>[a-zA-z0-9-_]+)://(?P<dataset_name>[a-zA-z0-9-_]+-v(?P<dataset_version>[0-9]+)))"
+)
+
+# Regular expression that matches the full three-part format of a benchmark URI:
+#     <protocol>://<dataset-name>/<benchmark-name>
+#
+# E.g. "benchmark://foo-v0/" or "benchmark://foo-v0/program".
+BENCHMARK_URI_RE = re.compile(
+    r"(?P<dataset>(?P<dataset_protocol>[a-zA-z0-9-_]+)://(?P<dataset_name>[a-zA-z0-9-_]+-v(?P<dataset_version>[0-9]+)))/(?P<benchmark_name>[^\s]*)$"
+)
+
+
+class Dataset(object):
+    """A collection of benchmarks that can be installed and made available to
+    an environment."""
+
+    def __init__(
+        self,
+        name: str,
+        description: str,
+        license: str,
+        site_data_base: Path,
+        long_description_url: Optional[str] = None,
+        random: Optional[np.random.Generator] = None,
+    ):
+        self._name = name
+        components = DATASET_NAME_RE.match(name)
+        if not components:
+            raise ValueError(
+                f"Invalid dataset name: '{name}'. "
+                "Dataset name must be in the form: '${protocol}://${name}-v${version}'"
+            )
+        self._description = description
+        self._license = license
+        self._protocol = components.group("dataset_protocol")
+        self._version = int(components.group("dataset_version"))
+        self._long_description_url = long_description_url
+
+        self.random = random or np.random.default_rng()
+        self.logger = logging.getLogger("compiler_gym.datasets")
+        self.logger.setLevel(get_logging_level())
+
+        # Set up the site data name.
+        basename = components.group("dataset_name")
+        self._site_data_path = Path(site_data_base).resolve() / self.protocol / basename
+
+    def __repr__(self):
+        return self.name
+
+    def seed(self, seed: int):
+        self.random = np.random.default_rng(seed)
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def description(self) -> str:
+        return self._description
+
+    @property
+    def license(self) -> str:
+        return self._license
+
+    @property
+    def protocol(self) -> str:
+        return self._protocol
+
+    @property
+    def version(self) -> int:
+        return self._version
+
+    @property
+    def long_description_url(self) -> str:
+        return self._long_description_url
+
+    @property
+    def site_data_path(self) -> Path:
+        return self._site_data_path
+
+    @property
+    def site_data_size_in_bytes(self) -> int:
+        if not self.site_data_path.is_dir():
+            return 0
+        return int(
+            subprocess.check_output(
+                ["du", "-sb", str(self.site_data_path)], universal_newlines=True
+            ).split()[0]
+        )
+
+    @property
+    def n(self) -> int:
+        """Get the number of benchmarks in the dataset.
+ + If the number of benchmarks is unbounded, return 0. + """ + return 0 + + def install(self) -> None: + """ + Implementing this method is optional. + """ + pass + + def uninstall(self) -> None: + if self.site_data_path.is_dir(): + shutil.rmtree(self.site_data_path) + + def benchmarks(self) -> Iterable[Benchmark]: + """Possibly lazy list of benchmarks.""" + # Default implementation. Subclasses may which to provide an optimized + # version. + yield from (self.benchmark(uri) for uri in self.benchmark_uris()) + + def benchmark_uris(self) -> Iterable[str]: + """Return an iterator over benchmark URIs that must be consistent + across runs. + + The order of the URIs must be consistent across runs. + """ + raise NotImplementedError("abstract class") + + def benchmark(self, uri: Optional[str] = None) -> Benchmark: + """ + :raise LookupError: If :code:`uri` is provided but does not exist. + """ + raise NotImplementedError("abstract class") + + +@deprecated( + version="0.1.5", + reason=( + "Use the new Dataset class. LegacyDataset will be removed in v0.1.6. " + "`More information `_." + ), +) class LegacyDataset(NamedTuple): """A collection of benchmarks for use by an environment. @@ -108,18 +243,9 @@ def activate(env, name: str) -> bool: already active. :raises ValueError: If there is no dataset with that name. """ - with fasteners.InterProcessLock(env.datasets_site_path / "LOCK"): - if (env.datasets_site_path / name).exists(): - # There is already an active benchmark set with this name. - return False - if not (env.inactive_datasets_site_path / name).exists(): - raise ValueError(f"Inactive dataset not found: {name}") - os.rename(env.inactive_datasets_site_path / name, env.datasets_site_path / name) - os.rename( - env.inactive_datasets_site_path / f"{name}.json", - env.datasets_site_path / f"{name}.json", - ) - return True + del env # Unused. + del name # Unused. + return False # Deprecated. No-op. @deprecated( @@ -136,17 +262,8 @@ def delete(env, name: str) -> bool: :return: :code:`True` if the dataset was deleted, else :code:`False` if already deleted. """ - with fasteners.InterProcessLock(env.datasets_site_path / "LOCK"): - deleted = False - if (env.datasets_site_path / name).exists(): - shutil.rmtree(str(env.datasets_site_path / name)) - os.unlink(str(env.datasets_site_path / f"{name}.json")) - deleted = True - if (env.inactive_datasets_site_path / name).exists(): - shutil.rmtree(str(env.inactive_datasets_site_path / name)) - os.unlink(str(env.inactive_datasets_site_path / f"{name}.json")) - deleted = True - return deleted + env.datasets[name].uninstall() + return True @deprecated( @@ -163,18 +280,12 @@ def deactivate(env, name: str) -> bool: :return: :code:`True` if the dataset was deactivated, else :code:`False` if already inactive. """ - with fasteners.InterProcessLock(env.datasets_site_path / "LOCK"): - if not (env.datasets_site_path / name).exists(): - return False - os.rename(env.datasets_site_path / name, env.inactive_datasets_site_path / name) - os.rename( - env.datasets_site_path / f"{name}.json", - env.inactive_datasets_site_path / f"{name}.json", - ) - return True + del env # Unused. + del name # Unused. + return False # Deprecated. No-op. -def require(env, dataset: Union[str, LegacyDataset]) -> bool: +def require(env, dataset: str) -> bool: """Require that the given dataset is available to the environment. This will download and activate the dataset if it is not already installed. 
@@ -193,89 +304,5 @@ def require(env, dataset: Union[str, LegacyDataset]) -> bool: :return: :code:`True` if the dataset was downloaded, or :code:`False` if the dataset was already available. """ - - def download_and_unpack_archive( - url: str, sha256: Optional[str] = None - ) -> LegacyDataset: - json_files_before = { - f - for f in env.inactive_datasets_site_path.iterdir() - if f.is_file() and f.name.endswith(".json") - } - tar_data = io.BytesIO(download(url, sha256)) - with tarfile.open(fileobj=tar_data, mode="r:bz2") as arc: - arc.extractall(str(env.inactive_datasets_site_path)) - json_files_after = { - f - for f in env.inactive_datasets_site_path.iterdir() - if f.is_file() and f.name.endswith(".json") - } - new_json = json_files_after - json_files_before - if not len(new_json): - raise OSError(f"Downloaded dataset {url} contains no metadata JSON file") - return LegacyDataset.from_json_file(list(new_json)[0]) - - def unpack_local_archive(path: Path) -> LegacyDataset: - if not path.is_file(): - raise FileNotFoundError(f"File not found: {path}") - json_files_before = { - f - for f in env.inactive_datasets_site_path.iterdir() - if f.is_file() and f.name.endswith(".json") - } - with tarfile.open(str(path), "r:bz2") as arc: - arc.extractall(str(env.inactive_datasets_site_path)) - json_files_after = { - f - for f in env.inactive_datasets_site_path.iterdir() - if f.is_file() and f.name.endswith(".json") - } - new_json = json_files_after - json_files_before - if not len(new_json): - raise OSError(f"Downloaded dataset {url} contains no metadata JSON file") - return LegacyDataset.from_json_file(list(new_json)[0]) - - with fasteners.InterProcessLock(env.datasets_site_path / "LOCK"): - # Resolve the name and URL of the dataset. - sha256 = None - if isinstance(dataset, LegacyDataset): - name, url = dataset.name, dataset.url - elif isinstance(dataset, str): - # Check if we have already downloaded the dataset. - if "://" in dataset: - name, url = None, dataset - dataset: Optional[LegacyDataset] = None - else: - try: - dataset: Optional[LegacyDataset] = env.available_datasets[dataset] - except KeyError: - raise ValueError(f"Dataset not found: {dataset}") - name, url, sha256 = dataset.name, dataset.url, dataset.sha256 - else: - raise TypeError( - f"require() called with unsupported type: {type(dataset).__name__}" - ) - - if dataset and dataset.deprecated: - warnings.warn( - f"Dataset '{dataset.name}' is deprecated as of CompilerGym " - f"release {dataset.deprecated_since}, please update to the " - "latest available version", - DeprecationWarning, - ) - - # Check if we have already downloaded the dataset. - if name: - if (env.datasets_site_path / name).is_dir(): - # Dataset is already downloaded and active. - return False - elif not (env.inactive_datasets_site_path / name).is_dir(): - # Dataset is downloaded but inactive. - name = download_and_unpack_archive(url, sha256=sha256).name - elif url.startswith("file:///"): - name = unpack_local_archive(Path(url[len("file:///") :])).name - else: - name = download_and_unpack_archive(url, sha256=sha256).name - - activate(env, name) - return True + env.datasets.dataset(dataset).install() + return True diff --git a/compiler_gym/datasets/datasets.py b/compiler_gym/datasets/datasets.py new file mode 100644 index 0000000000..f9c5b25277 --- /dev/null +++ b/compiler_gym/datasets/datasets.py @@ -0,0 +1,98 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import warnings +from pathlib import Path +from typing import Dict, Iterable, Optional, Union + +import numpy as np + +from compiler_gym.datasets.benchmark import Benchmark +from compiler_gym.datasets.dataset import BENCHMARK_URI_RE, Dataset + + +class Datasets(object): + """TODO.""" + + def __init__( + self, + datasets: Iterable[Dataset], + site_data_path: Path, + random: Optional[np.random.Generator] = None, + ): + # A look-up table mapping dataset names to Dataset instances. + self._datasets: Dict[str, Dataset] = {d.name: d for d in datasets} + self._site_data_path = Path(site_data_path) + self._site_data_path.mkdir(exist_ok=True, parents=True) + self.random = random or np.random.default_rng() + + def seed(self, seed: int): + self.random = np.random.default_rng(seed) + + @property + def site_data_path(self) -> Path: + return self._site_data_path + + def datasets(self) -> Iterable[Dataset]: + """ + Iteration order is consistent between runs. + """ + yield from sorted(self._datasets.values(), key=lambda d: d.name) + + def __getitem__(self, name: str) -> Dataset: + return self.dataset(name) + + def __iter__(self) -> Iterable[Dataset]: + yield from self.datasets() + + def dataset(self, name: str) -> Dataset: + if "://" not in name: + name = f"benchmark://{name}" + if name not in self._datasets: + raise LookupError(f"Dataset not found: '{name}'") + return self._datasets[name] + + def add(self, dataset: Dataset) -> None: + if dataset.name in self._datasets: + warnings.warn(f"Overwriting existing dataset '{dataset.name}'") + self._datasets[dataset.name] = dataset + + def remove(self, dataset: Union[str, Dataset]) -> bool: + dataset_name: str = dataset.name if isinstance(dataset, Dataset) else dataset + + if dataset_name in self._datasets: + del self._datasets[dataset_name] + return True + + return False + + def benchmarks(self) -> Iterable[Benchmark]: + for dataset in self.datasets(): + yield from dataset.benchmarks() + + def benchmark_uris(self) -> Iterable[str]: + for dataset in self.datasets(): + yield from dataset.benchmark_uris() + + def benchmark(self, uri: Optional[str] = None) -> Benchmark: + if not self._datasets: + raise ValueError("No datasets available") + + if uri is None: + dataset = self.random.choice(list(self._datasets.values())) + return dataset.benchmark() + + # Prepend the default benchmark:// protocol on URIs + if "://" not in uri: + uri = f"benchmark://{uri}" + + match = BENCHMARK_URI_RE.match(uri) + if not match: + raise ValueError(f"Invalid benchmark URI: '{uri}'") + + dataset_name = match.group("dataset") + if dataset_name not in self._datasets: + raise LookupError(f"Dataset not found: '{dataset_name}'") + + return self._datasets[dataset_name].benchmark(uri) diff --git a/compiler_gym/datasets/tar_dataset.py b/compiler_gym/datasets/tar_dataset.py new file mode 100644 index 0000000000..e23f9b2d7b --- /dev/null +++ b/compiler_gym/datasets/tar_dataset.py @@ -0,0 +1,268 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+import io +import os +import tarfile +from gzip import GzipFile +from pathlib import Path +from threading import Lock +from typing import Iterable, List, Optional + +import fasteners + +from compiler_gym.datasets.dataset import Benchmark, Dataset +from compiler_gym.util.decorators import memoized_property +from compiler_gym.util.download import download + + +class FilesystemDirectoryDataset(Dataset): + def __init__( + self, + dataset_root: Path, + benchmark_file_suffix: str = "", + **dataset_args, + ): + super().__init__(**dataset_args) + self.dataset_root = dataset_root + self.benchmark_file_suffix = benchmark_file_suffix + + @memoized_property + def n(self) -> int: + """Get the number of benchmarks in the dataset. + + If the number of benchmarks is unbounded, return 0. + """ + self.install() + return sum( + [ + sum(1 for f in files if f.endswith(self.benchmark_file_suffix)) + for (_, _, files) in os.walk(self.dataset_root) + ] + ) + + def benchmark_uris(self) -> Iterable[str]: + """Return an iterator over benchmark URIs that must be consistent + across runs. + + The order of the URIs must be consistent across runs. + """ + self.install() + for root, dirs, files in os.walk(self.dataset_root): + dirs.sort() + reldir = root[len(str(self.dataset_root)) + 1 :] + for filename in sorted(files): + if not filename.endswith(self.benchmark_file_suffix): + continue + name_stem = filename[: -len(self.benchmark_file_suffix)] + # Use os.path.join() rather than simple '/' concaentation as + # reldir may be empty. + yield os.path.join(self.name, reldir, name_stem) + + def benchmarks(self) -> Iterable[Benchmark]: + """Possibly lazy list of benchmarks.""" + # Default implementation. Subclasses may which to provide an optimized + # version. + yield from (self.benchmark(uri) for uri in self.benchmark_uris()) + + def benchmark(self, uri: Optional[str] = None) -> Benchmark: + """ + :raise LookupError: If :code:`uri` is provided but does not exist. + """ + self.install() + if uri is None: + return self._get_benchmark_by_index(self.random.integers(self.n)) + + relpath = f"{uri[len(self.name) + 1:]}{self.benchmark_file_suffix}" + abspath = self.dataset_root / relpath + if not abspath.is_file(): + raise LookupError(f"Benchmark not found: {uri}") + return Benchmark.from_file(uri, abspath) + + def _get_benchmark_by_index(self, n: int) -> Benchmark: + i = 0 + for root, dirs, files in os.walk(self.dataset_root): + dirs.sort() + reldir = root[len(str(self.dataset_root)) + 1 :] + for filename in sorted(files): + if not filename.endswith(self.benchmark_file_suffix): + continue + if i == n: + name_stem = filename[: -len(self.benchmark_file_suffix)] + # Use os.path.join() rather than simple '/' concaentation as + # reldir may be empty. 
+                    uri = os.path.join(self.name, reldir, name_stem)
+                    return Benchmark.from_file(uri, f"{root}/{filename}")
+                i += 1
+        raise FileNotFoundError(f"Could not find benchmark with index {n} / {self.n}")
+
+
+class TarDataset(Dataset):
+
+    # TODO: Subclass FilesystemDirectoryDataset
+
+    def __init__(
+        self,
+        tar_url: str,
+        benchmark_file_suffix: str = "",
+        strip_prefix: str = "",
+        tar_sha256: Optional[str] = None,
+        **dataset_args,
+    ):
+        super().__init__(**dataset_args)
+        self.tar_url = tar_url
+        self.tar_sha256 = tar_sha256
+        self.benchmark_file_suffix = benchmark_file_suffix
+        self.strip_prefix = strip_prefix
+
+        self._tar_extracted_marker = self.site_data_path / ".extracted"
+        self._tar_lock = Lock()
+        self._tar_lockfile = self.site_data_path / "LOCK"
+        self._tar_data = self.site_data_path / "contents" / self.strip_prefix
+
+    def install(self) -> None:
+        """Download and unpack the tar archive, if not already installed."""
+        if self._tar_extracted_marker.is_file():
+            return
+
+        self.logger.info("Downloading %s dataset", self.name)
+        with self._tar_lock:
+            with fasteners.InterProcessLock(self._tar_lockfile):
+                tar_data = io.BytesIO(download(self.tar_url, self.tar_sha256))
+                with tarfile.open(fileobj=tar_data, mode="r:bz2") as arc:
+                    arc.extractall(str(self.site_data_path / "contents"))
+
+                if self.strip_prefix and not self._tar_data.is_dir():
+                    raise FileNotFoundError(
+                        f"Directory prefix '{self.strip_prefix}' not found in dataset '{self.name}'"
+                    )
+
+                self._tar_extracted_marker.touch()
+
+    @memoized_property
+    def n(self) -> int:
+        """Get the number of benchmarks in the dataset.
+
+        If the number of benchmarks is unbounded, return 0.
+        """
+        self.install()
+        return sum(
+            [
+                sum(1 for f in files if f.endswith(self.benchmark_file_suffix))
+                for (_, _, files) in os.walk(self._tar_data)
+            ]
+        )
+
+    def benchmarks(self) -> Iterable[Benchmark]:
+        """Possibly lazy list of benchmarks."""
+        # Default implementation. Subclasses may wish to provide an optimized
+        # version.
+        yield from (self.benchmark(uri) for uri in self.benchmark_uris())
+
+    def benchmark_uris(self) -> Iterable[str]:
+        """Return an iterator over benchmark URIs that must be consistent
+        across runs.
+
+        The order of the URIs must be consistent across runs.
+        """
+        self.install()
+        for root, dirs, files in os.walk(self._tar_data):
+            dirs.sort()
+            reldir = root[len(str(self._tar_data)) + 1 :]
+            for filename in sorted(files):
+                if not filename.endswith(self.benchmark_file_suffix):
+                    continue
+                name_stem = filename[: -len(self.benchmark_file_suffix)]
+                # Use os.path.join() rather than simple '/' concatenation as
+                # reldir may be empty.
+                yield os.path.join(self.name, reldir, name_stem)
+
+    def benchmark(self, uri: Optional[str] = None) -> Benchmark:
+        """
+        :raise LookupError: If :code:`uri` is provided but does not exist.
+        """
+        self.install()
+        if uri is None:
+            return self._get_benchmark_by_index(self.random.integers(self.n))
+
+        relpath = f"{uri[len(self.name) + 1:]}{self.benchmark_file_suffix}"
+        abspath = self._tar_data / relpath
+        if not abspath.is_file():
+            raise LookupError(f"Benchmark not found: {uri}")
+        return Benchmark.from_file(uri, abspath)
+
+    def _get_benchmark_by_index(self, n: int) -> Benchmark:
+        i = 0
+        for root, dirs, files in os.walk(self._tar_data):
+            dirs.sort()
+            reldir = root[len(str(self._tar_data)) + 1 :]
+            for filename in sorted(files):
+                if not filename.endswith(self.benchmark_file_suffix):
+                    continue
+                if i == n:
+                    name_stem = filename[: -len(self.benchmark_file_suffix)]
+                    # Use os.path.join() rather than simple '/' concatenation as
+                    # reldir may be empty.
+                    uri = os.path.join(self.name, reldir, name_stem)
+                    return Benchmark.from_file(uri, f"{root}/{filename}")
+                i += 1
+        raise FileNotFoundError(f"Could not find benchmark with index {n} / {self.n}")
+
+
+class TarDatasetWithManifest(TarDataset):
+    def __init__(self, manifest_url: str, manifest_sha256: str, **dataset_args):
+        """Construct a tar dataset with an accompanying manifest of benchmark
+        URIs.
+
+        :param manifest_url: The URL of a gzip-compressed text file containing a
+            list of benchmark URIs, one per line.
+
+        :param manifest_sha256: The sha256 checksum of the compressed manifest
+            file.
+        """
+        super().__init__(**dataset_args)
+        self.manifest_url = manifest_url
+        self.manifest_sha256 = manifest_sha256
+        self._manifest_path = self.site_data_path / "manifest.txt"
+
+    def _install_manifest(self) -> Optional[List[str]]:
+        if self._manifest_path.is_file():
+            return None
+        with self._tar_lock:
+            with fasteners.InterProcessLock(self._tar_lockfile):
+                self.logger.debug("Downloading %s manifest", self.name)
+                manifest_data = io.BytesIO(
+                    download(self.manifest_url, self.manifest_sha256)
+                )
+                with GzipFile(fileobj=manifest_data) as gzipf:
+                    manifest = gzipf.read().decode("utf-8").strip()
+                with open(self._manifest_path, "w") as f:
+                    f.write(manifest)
+                return manifest.split("\n")
+
+    @memoized_property
+    def _benchmark_uris(self) -> List[str]:
+        uris = self._install_manifest()
+        if not uris:
+            with open(self._manifest_path) as f:
+                uris = f.read().strip().split("\n")
+        self.logger.debug("Read %s manifest, %d entries", self.name, len(uris))
+        return uris
+
+    @memoized_property
+    def n(self) -> int:
+        """Get the number of benchmarks in the dataset.
+
+        If the number of benchmarks is unbounded, return 0.
+        """
+        return len(self._benchmark_uris)
+
+    def benchmark_uris(self) -> Iterable[str]:
+        """Return an iterator over benchmark URIs that must be consistent
+        across runs.
+
+        The order of the URIs must be consistent across runs.
+ """ + yield from self._benchmark_uris diff --git a/compiler_gym/envs/BUILD b/compiler_gym/envs/BUILD index ba75aca0e1..63519efbe8 100644 --- a/compiler_gym/envs/BUILD +++ b/compiler_gym/envs/BUILD @@ -21,7 +21,7 @@ py_library( deps = [ "//compiler_gym:compiler_env_state", "//compiler_gym:validation_result", - "//compiler_gym/datasets:dataset", + "//compiler_gym/datasets", "//compiler_gym/service", "//compiler_gym/service/proto", "//compiler_gym/spaces", diff --git a/compiler_gym/envs/compiler_env.py b/compiler_gym/envs/compiler_env.py index 4d2ebcd5fc..30c5e2b47c 100644 --- a/compiler_gym/envs/compiler_env.py +++ b/compiler_gym/envs/compiler_env.py @@ -5,7 +5,6 @@ """This module defines the OpenAI gym interface for compilers.""" import logging import numbers -import os import sys import warnings from collections.abc import Iterable as IterableType @@ -15,13 +14,14 @@ from time import time from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union -import fasteners import gym import numpy as np +from deprecated.sphinx import deprecated from gym.spaces import Space from compiler_gym.compiler_env_state import CompilerEnvState -from compiler_gym.datasets.dataset import LegacyDataset, require +from compiler_gym.datasets.dataset import Dataset, LegacyDataset, require +from compiler_gym.datasets.datasets import Benchmark, Datasets from compiler_gym.service import ( CompilerGymServiceConnection, ConnectionOpts, @@ -30,14 +30,13 @@ ServiceTransportError, observation_t, ) +from compiler_gym.service.proto import AddBenchmarkRequest +from compiler_gym.service.proto import Benchmark as BenchmarkProto from compiler_gym.service.proto import ( - AddBenchmarkRequest, - Benchmark, EndSessionReply, EndSessionRequest, ForkSessionReply, ForkSessionRequest, - GetBenchmarksRequest, GetVersionReply, GetVersionRequest, StartSessionRequest, @@ -45,6 +44,7 @@ ) from compiler_gym.spaces import NamedDiscrete, Reward from compiler_gym.util.debug_util import get_logging_level +from compiler_gym.util.runfiles_path import site_data_path from compiler_gym.util.timer import Timer from compiler_gym.validation_result import ValidationError, ValidationResult from compiler_gym.views import ObservationSpaceSpec, ObservationView, RewardView @@ -128,10 +128,6 @@ class CompilerEnv(gym.Env): Default range is (-inf, +inf). :vartype reward_range: Tuple[float, float] - :ivar datasets_site_path: The filesystem path used by the service - to store benchmarks. - :vartype datasets_site_path: Optional[Path] - :ivar available_datasets: A mapping from dataset name to :class:`LegacyDataset` objects that are available to download. :vartype available_datasets: Dict[str, LegacyDataset] @@ -154,6 +150,7 @@ def __init__( self, service: Union[str, Path], rewards: Optional[List[Reward]] = None, + datasets: Optional[Iterable[Dataset]] = None, benchmark: Optional[Union[str, Benchmark]] = None, observation_space: Optional[Union[str, ObservationSpaceSpec]] = None, reward_space: Optional[Union[str, Reward]] = None, @@ -161,6 +158,7 @@ def __init__( connection_settings: Optional[ConnectionOpts] = None, service_connection: Optional[CompilerGymServiceConnection] = None, logging_level: Optional[int] = None, + datasets_site_path: Optional[Path] = None, ): """Construct and initialize a CompilerGym service environment. 
@@ -222,7 +220,6 @@ def __init__( self._service_endpoint: Union[str, Path] = service self._connection_settings = connection_settings or ConnectionOpts() - self.datasets_site_path: Optional[Path] = None self.available_datasets: Dict[str, LegacyDataset] = {} self.action_space_name = action_space @@ -232,6 +229,10 @@ def __init__( opts=self._connection_settings, logger=self.logger, ) + self.datasets = Datasets( + datasets=datasets or [], + site_data_path=datasets_site_path or site_data_path("datasets"), + ) # If no reward space is specified, generate some from numeric observation spaces rewards = rewards or [ @@ -247,12 +248,8 @@ def __init__( # The benchmark that is currently being used, and the benchmark that # the user requested. Those do not always correlate, since the user # could request a random benchmark. - self._benchmark_in_use_uri: Optional[str] = None - self._user_specified_benchmark_uri: Optional[str] = None - # A map from benchmark URIs to Benchmark messages. We keep track of any - # user-provided custom benchmarks so that we can register them with a - # reset service. - self._custom_benchmarks: Dict[str, Benchmark] = {} + self._benchmark_in_use: Optional[Benchmark] = None + self._user_specified_benchmark: Optional[Benchmark] = None # Normally when the benchmark is changed the updated value is not # reflected until the next call to reset(). We make an exception for # constructor-time arguments as otherwise the behavior of the benchmark @@ -268,7 +265,7 @@ def __init__( # By forcing the benchmark-in-use URI at constructor time, the first # env.benchmark returns the name as expected. self.benchmark = benchmark - self._benchmark_in_use_uri = self._user_specified_benchmark_uri + self._benchmark_in_use = self._user_specified_benchmark # Process the available action, observation, and reward spaces. self.action_spaces = [ @@ -362,17 +359,6 @@ def state(self) -> CompilerEnvState: commandline=self.commandline(), ) - @property - def inactive_datasets_site_path(self) -> Optional[Path]: - """The filesystem path used to store inactive benchmarks.""" - if self.datasets_site_path: - return ( - self.datasets_site_path.parent - / f"{self.datasets_site_path.name}.inactive" - ) - else: - return None - @property def action_space(self) -> NamedDiscrete: """The current action space. @@ -428,7 +414,7 @@ def benchmark(self) -> Optional[str]: To return to random benchmark selection, set this property to :code:`None`: """ - return self._benchmark_in_use_uri + return self._benchmark_in_use.uri if self._benchmark_in_use else None @benchmark.setter def benchmark(self, benchmark: Optional[Union[str, Benchmark]]): @@ -438,26 +424,18 @@ def benchmark(self, benchmark: Optional[Union[str, Benchmark]]): ) if benchmark is None: self.logger.debug("Unsetting the forced benchmark") - self._user_specified_benchmark_uri = None + self._user_specified_benchmark = None elif isinstance(benchmark, str): + benchmark = self.datasets.benchmark(benchmark) self.logger.debug("Setting benchmark by name: %s", benchmark) - # If the user requested a benchmark by URI, e.g. - # benchmark://cBench-v1/dijkstra, require the dataset (cBench-v1) - # automatically. 
- if self.datasets_site_path: - components = benchmark.split("://") - if len(components) == 1 or components[0] == "benchmark": - components = components[-1].split("/") - if len(components) > 1: - self.logger.info("Requiring dataset %s", components[0]) - self.require_dataset(components[0]) - self._user_specified_benchmark_uri = benchmark + self._user_specified_benchmark = benchmark elif isinstance(benchmark, Benchmark): - self.logger.debug("Setting benchmark data: %s", benchmark.uri) - self._user_specified_benchmark_uri = benchmark.uri - self._add_custom_benchmarks([benchmark]) + self.logger.debug("Setting benchmark: %s", benchmark) + self._user_specified_benchmark = benchmark else: - raise TypeError(f"Unsupported benchmark type: {type(benchmark).__name__}") + raise TypeError( + f"Expected a Benchmark instance, received: '{type(benchmark).__name__}'" + ) @property def reward_space(self) -> Optional[Reward]: @@ -584,23 +562,13 @@ def fork(self) -> "CompilerEnv": new_env._session_id = reply.session_id # pylint: disable=protected-access new_env.observation.session_id = reply.session_id - # Re-register any custom benchmarks with the new environment. - if self._custom_benchmarks: - new_env._add_custom_benchmarks( # pylint: disable=protected-access - list(self._custom_benchmarks.values()).copy() - ) - # Now that we have initialized the environment with the current state, # set the benchmark so that calls to new_env.reset() will correctly # revert the environment to the initial benchmark state. - new_env._user_specified_benchmark_uri = ( # pylint: disable=protected-access - self.benchmark - ) + new_env._user_specified_benchmark = self._benchmark_in_use # Set the "visible" name of the current benchmark to hide the fact that # we loaded from a custom bitcode file. - new_env._benchmark_in_use_uri = ( # pylint: disable=protected-access - self.benchmark - ) + new_env._benchmark_in_use = self._benchmark_in_use # Create copies of the mutable reward and observation spaces. This # is required to correctly calculate incremental updates. @@ -656,7 +624,7 @@ def __del__(self): def reset( # pylint: disable=arguments-differ self, - benchmark: Optional[Union[str, Benchmark]] = None, + benchmark: Optional[Union[str, BenchmarkProto]] = None, action_space: Optional[str] = None, retry_count: int = 0, ) -> Optional[observation_t]: @@ -684,8 +652,6 @@ def reset( # pylint: disable=arguments-differ self.service = CompilerGymServiceConnection( self._service_endpoint, self._connection_settings ) - # Re-register the custom benchmarks with the new service. - self._add_custom_benchmarks(self._custom_benchmarks.values()) self.action_space_name = action_space or self.action_space_name @@ -697,25 +663,37 @@ def reset( # pylint: disable=arguments-differ ) self._session_id = None - # Update the user requested benchmark, if provided. NOTE: This means - # that env.reset(benchmark=None) does NOT unset a forced benchmark. + # Update the user requested benchmark, if provided, or pick one + # randomly. NOTE: This means that env.reset(benchmark=None) does NOT + # unset a forced benchmark. 
if benchmark: self.benchmark = benchmark + benchmark: Benchmark = self._user_specified_benchmark + elif self._user_specified_benchmark: + benchmark = self._user_specified_benchmark + else: + benchmark: Benchmark = self.datasets.benchmark() + + self._benchmark_in_use = benchmark + start_session_request = StartSessionRequest( + benchmark=benchmark.uri, + action_space=( + [a.name for a in self.action_spaces].index(self.action_space_name) + if self.action_space_name + else 0 + ), + ) try: - reply = self.service( - self.service.stub.StartSession, - StartSessionRequest( - benchmark=self._user_specified_benchmark_uri, - action_space=( - [a.name for a in self.action_spaces].index( - self.action_space_name - ) - if self.action_space_name - else 0 - ), - ), + reply = self.service(self.service.stub.StartSession, start_session_request) + except FileNotFoundError: + # The benchmark was not found, so try adding it and repeating the + # request. + self.service( + self.service.stub.AddBenchmark, + AddBenchmarkRequest(benchmark=[benchmark.proto]), ) + reply = self.service(self.service.stub.StartSession, start_session_request) except (ServiceError, ServiceTransportError, TimeoutError) as e: # Abort and retry on error. self.logger.warning("%s on reset(): %s", type(e).__name__, e) @@ -727,7 +705,6 @@ def reset( # pylint: disable=arguments-differ retry_count=retry_count + 1, ) - self._benchmark_in_use_uri = reply.benchmark self._session_id = reply.session_id self.observation.session_id = reply.session_id self.reward.get_cost = self.observation.__getitem__ @@ -858,10 +835,9 @@ def render( raise ValueError(f"Invalid mode: {mode}") @property - def benchmarks(self) -> List[str]: - """Enumerate the list of available benchmarks.""" - reply = self.service(self.service.stub.GetBenchmarks, GetBenchmarksRequest()) - return list(reply.benchmark) + def benchmarks(self) -> Iterable[str]: + """Enumerate a (possible unbounded) list of available benchmarks.""" + return self.datasets.benchmark_uris() def _make_action_space(self, name: str, entries: List[str]) -> Space: """Create an action space from the given values. @@ -891,7 +867,8 @@ def _reward_view_type(self): """ return RewardView - def require_datasets(self, datasets: List[Union[str, LegacyDataset]]) -> bool: + @deprecated(version="0.1.5", reason="TODO") + def require_datasets(self, datasets: List[str]) -> bool: """Require that the given datasets are available to the environment. Example usage: @@ -915,20 +892,9 @@ def require_datasets(self, datasets: List[Union[str, LegacyDataset]]) -> bool: dataset_installed = False for dataset in datasets: dataset_installed |= require(self, dataset) - if dataset_installed: - # Signal to the compiler service that the contents of the site data - # directory has changed. - self.logger.debug("Initiating service-side scan of dataset directory") - self.service( - self.service.stub.AddBenchmark, - AddBenchmarkRequest( - benchmark=[Benchmark(uri="service://scan-site-data")] - ), - ) - self.make_manifest_file() return dataset_installed - def require_dataset(self, dataset: Union[str, LegacyDataset]) -> bool: + def require_dataset(self, dataset: str) -> bool: """Require that the given dataset is available to the environment. Alias for @@ -941,30 +907,6 @@ def require_dataset(self, dataset: Union[str, LegacyDataset]) -> bool: """ return self.require_datasets([dataset]) - def make_manifest_file(self) -> Path: - """Create the MANIFEST file. - - :return: The path of the manifest file. 
- """ - with fasteners.InterProcessLock(self.datasets_site_path / "LOCK"): - manifest_path = ( - self.datasets_site_path.parent - / f"{self.datasets_site_path.name}.MANIFEST" - ) - with open(str(manifest_path), "w") as f: - for root, _, files in os.walk(self.datasets_site_path): - print( - "\n".join( - [ - f"{root[len(str(self.datasets_site_path)) + 1:]}/{f}" - for f in files - if not f.endswith(".json") and f != "LOCK" - ] - ), - file=f, - ) - return manifest_path - def register_dataset(self, dataset: LegacyDataset) -> bool: """Register a new dataset. @@ -992,26 +934,6 @@ def register_dataset(self, dataset: LegacyDataset) -> bool: self.available_datasets[dataset.name] = dataset return True - def _add_custom_benchmarks(self, benchmarks: List[Benchmark]) -> None: - """Register custom benchmarks with the compiler service. - - Benchmark registration occurs automatically using the - :meth:`env.benchmark ` - property, there is usually no need to call this method yourself. - - :param benchmarks: The benchmarks to register. - """ - if not benchmarks: - return - - for benchmark in benchmarks: - self._custom_benchmarks[benchmark.uri] = benchmark - - self.service( - self.service.stub.AddBenchmark, - AddBenchmarkRequest(benchmark=benchmarks), - ) - def apply(self, state: CompilerEnvState) -> None: # noqa """Replay this state on the given an environment. diff --git a/compiler_gym/envs/llvm/BUILD b/compiler_gym/envs/llvm/BUILD index 0522017680..d35354c0f4 100644 --- a/compiler_gym/envs/llvm/BUILD +++ b/compiler_gym/envs/llvm/BUILD @@ -28,6 +28,7 @@ py_library( ], visibility = ["//compiler_gym:__subpackages__"], deps = [ + "//compiler_gym/datasets", "//compiler_gym/service/proto", "//compiler_gym/util", ], @@ -42,7 +43,7 @@ py_library( ], visibility = ["//tests:__subpackages__"], deps = [ - "//compiler_gym/datasets:dataset", + "//compiler_gym/datasets", "//compiler_gym/util", ], ) @@ -57,7 +58,9 @@ py_library( ":benchmarks", ":legacy_datasets", ":llvm_rewards", + "//compiler_gym/datasets", "//compiler_gym/envs:compiler_env", + "//compiler_gym/envs/llvm/datasets", "//compiler_gym/spaces", "//compiler_gym/third_party/autophase", "//compiler_gym/third_party/inst2vec", diff --git a/compiler_gym/envs/llvm/benchmarks.py b/compiler_gym/envs/llvm/benchmarks.py index c1a5283461..e2319cf8fe 100644 --- a/compiler_gym/envs/llvm/benchmarks.py +++ b/compiler_gym/envs/llvm/benchmarks.py @@ -14,7 +14,7 @@ from pathlib import Path from typing import Iterable, List, Optional, Union -from compiler_gym.service.proto import Benchmark, File +from compiler_gym.datasets import Benchmark from compiler_gym.util.runfiles_path import cache_path, runfiles_path CLANG = runfiles_path("compiler_gym/third_party/llvm/bin/clang") @@ -220,7 +220,7 @@ def make_benchmark( :func:`get_system_includes`. :param timeout: The maximum number of seconds to allow clang to run before terminating. - :return: A :code:`Benchmark` message. + :return: A :code:`Benchmark` instance. :raises FileNotFoundError: If any input sources are not found. :raises TypeError: If the inputs are of unsupported types. :raises OSError: If a compilation job fails. @@ -290,9 +290,7 @@ def _add_path(path: Path): # Shortcut if we only have a single pre-compiled bitcode. 
if len(bitcodes) == 1 and not clang_jobs: bitcode = bitcodes[0] - return Benchmark( - uri=f"file:///{bitcode}", program=File(uri=f"file:///{bitcode}") - ) + return Benchmark.from_file(uri=f"file:///{bitcode}", path=bitcode) tmpdir_root = cache_path(".") tmpdir_root.mkdir(exist_ok=True, parents=True) @@ -335,7 +333,6 @@ def _add_path(path: Path): with open(str(list(bitcodes + clang_outs)[0]), "rb") as f: bitcode = f.read() - timestamp = datetime.now().strftime(f"%Y%m%HT%H%M%S-{random.randrange(16**4):04x}") - return Benchmark( - uri=f"benchmark://user/{timestamp}", program=File(contents=bitcode) - ) + timestamp = datetime.now().strftime("%Y%m%HT%H%M%S") + uri = f"benchmark://user/{timestamp}-{random.randrange(16**4):04x}" + return Benchmark.from_file_contents(uri, bitcode) diff --git a/compiler_gym/envs/llvm/datasets/BUILD b/compiler_gym/envs/llvm/datasets/BUILD new file mode 100644 index 0000000000..2787a0a28b --- /dev/null +++ b/compiler_gym/envs/llvm/datasets/BUILD @@ -0,0 +1,24 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +load("@rules_python//python:defs.bzl", "py_library") + +py_library( + name = "datasets", + srcs = [ + "__init__.py", + "cBench.py", + "llvm_stress.py", + ], + data = [ + "//compiler_gym/third_party/llvm:llvm-as", + "//compiler_gym/third_party/llvm:llvm-stress", + ], + visibility = ["//visibility:public"], + deps = [ + "//compiler_gym/datasets", + "//compiler_gym/service/proto", + "//compiler_gym/util", + ], +) diff --git a/compiler_gym/envs/llvm/datasets/__init__.py b/compiler_gym/envs/llvm/datasets/__init__.py new file mode 100644 index 0000000000..2182dc9105 --- /dev/null +++ b/compiler_gym/envs/llvm/datasets/__init__.py @@ -0,0 +1,172 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+from pathlib import Path +from typing import Iterable + +from compiler_gym.datasets import Dataset, TarDatasetWithManifest +from compiler_gym.envs.llvm.datasets.cBench import CBenchDataset, CBenchLegacyDataset + + +class BlasDataset(TarDatasetWithManifest): + def __init__(self, site_data_base: Path): + super().__init__( + name="benchmark://blas-v0", + tar_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-blas-v0.tar.bz2", + tar_sha256="e724a8114709f8480adeb9873d48e426e8d9444b00cddce48e342b9f0f2b096d", + manifest_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-blas-v0-manifest.gz", + manifest_sha256="1d561808bc80e72a33f13b376c10502f1af2645ed6f0fb1851de1b746402db01", + long_description_url="https://github.com/spcl/ncc/tree/master/data", + license="BSD 3-Clause", + strip_prefix="blas-v0", + description="Basic linear algebra kernels", + benchmark_file_suffix=".bc", + site_data_base=site_data_base, + ) + + +class GitHubDataset(TarDatasetWithManifest): + def __init__(self, site_data_base: Path): + super().__init__( + name="benchmark://github-v0", + tar_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-github-v0.tar.bz2", + tar_sha256="880269dd7a5c2508ea222a2e54c318c38c8090eb105c0a87c595e9dd31720764", + manifest_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-github-v0-manifest.gz", + manifest_sha256="d7b42ef68c9b452233baa13303d5140d6b9bf15da2ba9d4e7b0ef73524611a42", + license="CC BY 4.0", + long_description_url="https://github.com/ctuning/ctuning-programs", + strip_prefix="github-v0", + description="Compile-only C/C++ objects from GitHub", + benchmark_file_suffix=".bc", + site_data_base=site_data_base, + ) + + +class LinuxDataset(TarDatasetWithManifest): + def __init__(self, site_data_base: Path): + super().__init__( + name="benchmark://linux-v0", + tar_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-linux-v0.tar.bz2", + tar_sha256="a1ae5c376af30ab042c9e54dc432f89ce75f9ebaee953bc19c08aff070f12566", + manifest_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-linux-v0-manifest.gz", + manifest_sha256="dbdf82046cb5779fc48f47c4899deea51d396daddee1dd448ccf411010788b96", + long_description_url="https://github.com/spcl/ncc/tree/master/data", + license="GPL-2.0", + strip_prefix="linux-v0", + description="Compile-only object files from C Linux kernel", + benchmark_file_suffix=".bc", + site_data_base=site_data_base, + ) + + +class MibenchDataset(TarDatasetWithManifest): + def __init__(self, site_data_base: Path): + super().__init__( + name="benchmark://mibench-v0", + tar_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-mibench-v0.tar.bz2", + tar_sha256="128c090c40b955b99fdf766da167a5f642018fb35c16a1d082f63be2e977eb13", + manifest_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-mibench-v0-manifest.gz", + manifest_sha256="059bc81b92d5942ac0ea74664b8268d1e64f5489dca66992b50b5ef7b527264e", + long_description_url="https://github.com/ctuning/ctuning-programs", + license="BSD 3-Clause", + strip_prefix="mibench-v0", + description="C benchmarks", + benchmark_file_suffix=".bc", + site_data_base=site_data_base, + ) + + +class NPBDataset(TarDatasetWithManifest): + def __init__(self, site_data_base: Path): + super().__init__( + name="benchmark://npb-v0", + tar_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-npb-v0.tar.bz2", + tar_sha256="793ac2e7a4f4ed83709e8a270371e65b724da09eaa0095c52e7f4209f63bb1f2", + 
manifest_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-npb-v0-manifest.gz", + manifest_sha256="edd5cf0863db49cee6551a7cabefc08e931295f3ba1e2990705f05442eb5ebbc", + long_description_url="https://github.com/spcl/ncc/tree/master/data", + license="NASA Open Source Agreement v1.3", + strip_prefix="npb-v0", + description="NASA Parallel Benchmarks", + benchmark_file_suffix=".bc", + site_data_base=site_data_base, + ) + + +class OpenCVDataset(TarDatasetWithManifest): + def __init__(self, site_data_base: Path): + super().__init__( + name="benchmark://opencv-v0", + tar_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-opencv-v0.tar.bz2", + tar_sha256="003df853bd58df93572862ca2f934c7b129db2a3573bcae69a2e59431037205c", + manifest_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-opencv-v0-manifest.gz", + manifest_sha256="e5fc1afbfbb978b2e6a5d4d7f3ffed7c612887fbdc1af5f5cba4d0ab29c3ed9b", + long_description_url="https://github.com/spcl/ncc/tree/master/data", + license="Apache 2.0", + strip_prefix="opencv-v0", + description="Compile-only object files from C++ OpenCV library", + benchmark_file_suffix=".bc", + site_data_base=site_data_base, + ) + + +class POJ104Dataset(TarDatasetWithManifest): + def __init__(self, site_data_base: Path): + super().__init__( + name="benchmark://poj104-v0", + tar_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-poj104-v0.tar.bz2", + tar_sha256="6254d629887f6b51efc1177788b0ce37339d5f3456fb8784415ed3b8c25cce27", + manifest_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-poj104-v0-manifest.gz", + manifest_sha256="ca68aec704d054a26046bc82aff17938e49f9083078dacc5f042c6051f2d2711", + long_description_url="https://sites.google.com/site/treebasedcnn/", + license="BSD 3-Clause", + strip_prefix="poj104-v0", + description="Solutions to programming programs", + benchmark_file_suffix=".bc", + site_data_base=site_data_base, + ) + + +class TensorflowDataset(TarDatasetWithManifest): + def __init__(self, site_data_base: Path): + super().__init__( + name="benchmark://tensorflow-v0", + tar_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-tensorflow-v0.tar.bz2", + tar_sha256="f77dd1988c772e8359e1303cc9aba0d73d5eb27e0c98415ac3348076ab94efd1", + manifest_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-tensorflow-v0-manifest.gz", + manifest_sha256="a78751b4562f27d330e4c20f34b5f1e670fcbe1a92172f0d01c6eba49b182576", + long_description_url="https://github.com/spcl/ncc/tree/master/data", + license="Apache 2.0", + strip_prefix="tensorflow-v0", + description="Compile-only object files from C++ TensorFlow library", + benchmark_file_suffix=".bc", + site_data_base=site_data_base, + ) + + +def get_llvm_datasets(site_data_base: Path) -> Iterable[Dataset]: + yield BlasDataset(site_data_base=site_data_base) + yield CBenchDataset(site_data_base=site_data_base) + yield CBenchLegacyDataset(site_data_base=site_data_base) + yield GitHubDataset(site_data_base=site_data_base) + yield LinuxDataset(site_data_base=site_data_base) + yield MibenchDataset(site_data_base=site_data_base) + yield NPBDataset(site_data_base=site_data_base) + yield OpenCVDataset(site_data_base=site_data_base) + yield POJ104Dataset(site_data_base=site_data_base) + yield TensorflowDataset(site_data_base=site_data_base) + + +__all__ = [ + "BlasDataset", + "CBenchDataset", + "CBenchLegacyDataset", + "GitHubDataset", + "LinuxDataset", + "MibenchDataset", + "NPBDataset", + "OpenCVDataset", + 
"POJ104Dataset", + "TensorflowDataset", +] diff --git a/compiler_gym/envs/llvm/datasets/cBench.py b/compiler_gym/envs/llvm/datasets/cBench.py new file mode 100644 index 0000000000..966594005f --- /dev/null +++ b/compiler_gym/envs/llvm/datasets/cBench.py @@ -0,0 +1,73 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import sys +from pathlib import Path + +from deprecated.sphinx import deprecated + +from compiler_gym.datasets import TarDatasetWithManifest + +_CBENCH_TARS = { + "darwin": ( + "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-cBench-v1-macos.tar.bz2", + "90b312b40317d9ee9ed09b4b57d378879f05e8970bb6de80dc8581ad0e36c84f", + ), + "linux": ( + "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-cBench-v1-linux.tar.bz2", + "601fff3944c866f6617e653b6eb5c1521382c935f56ca1f36a9f5cf1a49f3de5", + ), +} + + +class CBenchDataset(TarDatasetWithManifest): + def __init__(self, site_data_base: Path): + platform = {"darwin": "macos"}.get(sys.platform, sys.platform) + url, sha256 = _CBENCH_TARS[platform] + super().__init__( + name="benchmark://cBench-v1", + description="Runnable C benchmarks", + license="BSD 3-Clause", + long_description_url="https://github.com/ctuning/ctuning-programs", + tar_url=url, + tar_sha256=sha256, + manifest_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-cBench-v1-manifest.gz", + manifest_sha256="455636dde21013fb593afd47c3dc2d25401a8a8cfff5dde01d2e416f039149ba", + strip_prefix="cBench-v1", + benchmark_file_suffix=".bc", + site_data_base=site_data_base, + ) + + +_CBENCH_LEGACY_TARS = { + "darwin": ( + "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-cBench-v0-macos.tar.bz2", + "072a730c86144a07bba948c49afe543e4f06351f1cb17f7de77f91d5c1a1b120", + ), + "linux": ( + "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-cBench-v0-linux.tar.bz2", + "9b5838a90895579aab3b9375e8eeb3ed2ae58e0ad354fec7eb4f8b31ecb4a360", + ), +} + + +@deprecated(version="0.1.4", reason="Please update to cBench-v1") +class CBenchLegacyDataset(TarDatasetWithManifest): + # TODO: Add deprecation notices + def __init__(self, site_data_base: Path): + platform = {"darwin": "macos"}.get(sys.platform, sys.platform) + url, sha256 = _CBENCH_LEGACY_TARS[platform] + super().__init__( + name="benchmark://cBench-v0", + description="Runnable C benchmarks", + license="BSD 3-Clause", + long_description_url="https://github.com/ctuning/ctuning-programs", + tar_url=url, + tar_sha256=sha256, + manifest_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-cBench-v0-manifest.gz", + manifest_sha256="d3e362b858cb0978a7b974c68a7332ed7b8f2f580cbb083c06313e8fd0009aae", + strip_prefix="cBench-v0", + benchmark_file_suffix=".bc", + site_data_base=site_data_base, + ) diff --git a/compiler_gym/envs/llvm/datasets/llvm_stress.py b/compiler_gym/envs/llvm/datasets/llvm_stress.py new file mode 100644 index 0000000000..c539700222 --- /dev/null +++ b/compiler_gym/envs/llvm/datasets/llvm_stress.py @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+import subprocess +from pathlib import Path +from typing import Iterable, Optional + +from compiler_gym.datasets import Benchmark, Dataset +from compiler_gym.util.runfiles_path import runfiles_path + +LLVM_AS = runfiles_path("compiler_gym/third_party/llvm/bin/llvm-as") +LLVM_STRESS = runfiles_path("compiler_gym/third_party/llvm/bin/llvm-stress") + +# The maximum value for the --seed argument to llvm-stress. +UINT_MAX = (2 << 32) - 1 + + +class LlvmStressDataset(Dataset): + def __init__(self, site_data_base: Path): + super().__init__( + name="generator://llvm-stress-v0", + description="Randomly generated LLVM-IR", + long_description_url="https://llvm.org/docs/CommandGuide/llvm-stress.html", + license="Apache License v2.0 with LLVM Exceptions", + site_data_base=site_data_base, + ) + + def benchmark_uris(self) -> Iterable[str]: + return (f"generator://llvm-stress-v0/{i}" for i in range(UINT_MAX)) + + def benchmark(self, uri: Optional[str] = None): + if uri is None: + seed = self.random.randint(UINT_MAX) + else: + seed = int(uri.split("/")[-1]) + + # Run llvm-stress with the given seed and pipe the output to llvm-as to + # assemble a bitcode. + llvm_stress = subprocess.Popen( + [str(LLVM_STRESS), f"--seed={seed}"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + llvm_as = subprocess.Popen( + [str(LLVM_AS), "-"], + stdin=llvm_stress.stdout, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + stdout, _ = llvm_as.communicate() + if llvm_stress.returncode or llvm_as.returncode: + raise OSError("Failed to generate benchmark") + + return Benchmark.from_file_contents(uri, stdout) diff --git a/compiler_gym/envs/llvm/llvm_env.py b/compiler_gym/envs/llvm/llvm_env.py index bf545b1413..05d3786d4b 100644 --- a/compiler_gym/envs/llvm/llvm_env.py +++ b/compiler_gym/envs/llvm/llvm_env.py @@ -12,10 +12,11 @@ import numpy as np from gym.spaces import Dict as DictSpace +from compiler_gym.datasets import Benchmark from compiler_gym.envs.compiler_env import CompilerEnv -from compiler_gym.envs.llvm.benchmarks import make_benchmark +from compiler_gym.envs.llvm.benchmarks import ClangInvocation, make_benchmark +from compiler_gym.envs.llvm.datasets import get_llvm_datasets from compiler_gym.envs.llvm.legacy_datasets import ( - LLVM_DATASETS, get_llvm_benchmark_validation_callback, ) from compiler_gym.envs.llvm.llvm_rewards import ( @@ -77,7 +78,10 @@ class LlvmEnv(CompilerEnv): :vartype actions: List[int] """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, datasets_site_path: Optional[Path] = None, **kwargs): + datasets_site_path = datasets_site_path or site_data_path( + "llvm/10.0.0/bitcode_benchmarks" + ) super().__init__( *args, **kwargs, @@ -157,14 +161,9 @@ def __init__(self, *args, **kwargs): platform_dependent=True, ), ], + datasets_site_path=datasets_site_path, + datasets=get_llvm_datasets(site_data_base=datasets_site_path), ) - self.datasets_site_path = site_data_path("llvm/10.0.0/bitcode_benchmarks") - - # Register the LLVM datasets. 
- self.datasets_site_path.mkdir(parents=True, exist_ok=True) - self.inactive_datasets_site_path.mkdir(parents=True, exist_ok=True) - for dataset in LLVM_DATASETS: - self.register_dataset(dataset) self.inst2vec = _INST2VEC_ENCODER @@ -246,10 +245,89 @@ def reset(self, *args, **kwargs): self.require_dataset("cBench-v1") super().reset(*args, **kwargs) - @staticmethod - def make_benchmark(*args, **kwargs): - """Alias to :func:`llvm.make_benchmark() `.""" - return make_benchmark(*args, **kwargs) + def make_benchmark( + self, + inputs: Union[ + str, Path, ClangInvocation, List[Union[str, Path, ClangInvocation]] + ], + copt: Optional[List[str]] = None, + system_includes: bool = True, + timeout: int = 600, + ) -> Benchmark: + """Create a benchmark for use with this environment. + + This function takes one or more inputs and uses them to create a benchmark + that can be passed to :meth:`compiler_gym.envs.LlvmEnv.reset`. + + For single-source C/C++ programs, you can pass the path of the source file: + + >>> benchmark = make_benchmark('my_app.c') + >>> env = gym.make("llvm-v0") + >>> env.reset(benchmark=benchmark) + + The clang invocation used is roughly equivalent to: + + .. code-block:: + + $ clang my_app.c -O0 -c -emit-llvm -o benchmark.bc + + Additional compile-time arguments to clang can be provided using the + :code:`copt` argument: + + >>> benchmark = make_benchmark('/path/to/my_app.cpp', copt=['-O2']) + + If you need more fine-grained control over the options, you can directly + construct a :class:`ClangInvocation ` + to pass a list of arguments to clang: + + >>> benchmark = make_benchmark( + ClangInvocation(['/path/to/my_app.c'], timeout=10) + ) + + For multi-file programs, pass a list of inputs that will be compiled + separately and then linked to a single module: + + >>> benchmark = make_benchmark([ + 'main.c', + 'lib.cpp', + 'lib2.bc', + ]) + + If you already have prepared bitcode files, those can be linked and used + directly: + + >>> benchmark = make_benchmark([ + 'bitcode1.bc', + 'bitcode2.bc', + ]) + + .. note:: + LLVM bitcode compatibility is + `not guaranteed `_, + so you must ensure that any precompiled bitcodes are compatible with the + LLVM version used by CompilerGym, which can be queried using + :func:`LlvmEnv.compiler_version `. + + :param inputs: An input, or list of inputs. + :param copt: A list of command line options to pass to clang when compiling + source files. + :param system_includes: Whether to include the system standard libraries + during compilation jobs. This requires a system toolchain. See + :func:`get_system_includes`. + :param timeout: The maximum number of seconds to allow clang to run before + terminating. + :return: A :code:`Benchmark` instance. + :raises FileNotFoundError: If any input sources are not found. + :raises TypeError: If the inputs are of unsupported types. + :raises OSError: If a compilation job fails. + :raises TimeoutExpired: If a compilation job exceeds :code:`timeout` seconds. 
+ """ + return make_benchmark( + inputs=inputs, + copt=copt, + system_includes=system_includes, + timeout=timeout, + ) def _make_action_space(self, name: str, entries: List[str]) -> Commandline: flags = [ diff --git a/compiler_gym/envs/llvm/service/Benchmark.cc b/compiler_gym/envs/llvm/service/Benchmark.cc index c889eb63de..ccb425bff6 100644 --- a/compiler_gym/envs/llvm/service/Benchmark.cc +++ b/compiler_gym/envs/llvm/service/Benchmark.cc @@ -51,8 +51,10 @@ std::unique_ptr makeModuleOrDie(llvm::LLVMContext& context, const } // anonymous namespace Status readBitcodeFile(const fs::path& path, Bitcode* bitcode) { - std::ifstream ifs; - ifs.open(path.string()); + std::ifstream ifs(path.string()); + if (ifs.fail()) { + return Status(StatusCode::NOT_FOUND, fmt::format("File not found: \"{}\"", path.string())); + } ifs.seekg(0, std::ios::end); if (ifs.fail()) { @@ -93,35 +95,30 @@ std::unique_ptr makeModule(llvm::LLVMContext& context, const Bitco // A benchmark is an LLVM module and the LLVM context that owns it. Benchmark::Benchmark(const std::string& name, const Bitcode& bitcode, - const fs::path& workingDirectory, std::optional bitcodePath, - const BaselineCosts* baselineCosts) + const fs::path& workingDirectory, const BaselineCosts* baselineCosts) : context_(std::make_unique()), module_(makeModuleOrDie(*context_, bitcode, name)), baselineCosts_(baselineCosts ? *baselineCosts : getBaselineCosts(*module_, workingDirectory)), hash_(getModuleHash(*module_)), name_(name), - bitcodeSize_(bitcode.size()), - bitcodePath_(bitcodePath) {} + bitcodeSize_(bitcode.size()) {} Benchmark::Benchmark(const std::string& name, std::unique_ptr context, std::unique_ptr module, size_t bitcodeSize, - const fs::path& workingDirectory, std::optional bitcodePath, - const BaselineCosts* baselineCosts) + const fs::path& workingDirectory, const BaselineCosts* baselineCosts) : context_(std::move(context)), module_(std::move(module)), baselineCosts_(baselineCosts ? *baselineCosts : getBaselineCosts(*module_, workingDirectory)), hash_(getModuleHash(*module_)), name_(name), - bitcodeSize_(bitcodeSize), - bitcodePath_(bitcodePath) {} + bitcodeSize_(bitcodeSize) {} std::unique_ptr Benchmark::clone(const fs::path& workingDirectory) const { Bitcode bitcode; llvm::raw_svector_ostream ostream(bitcode); llvm::WriteBitcodeToFile(module(), ostream); - return std::make_unique(name(), bitcode, workingDirectory, bitcodePath(), - &baselineCosts()); + return std::make_unique(name(), bitcode, workingDirectory, &baselineCosts()); } } // namespace compiler_gym::llvm_service diff --git a/compiler_gym/envs/llvm/service/Benchmark.h b/compiler_gym/envs/llvm/service/Benchmark.h index 377063e8f9..c3f9e95f7c 100644 --- a/compiler_gym/envs/llvm/service/Benchmark.h +++ b/compiler_gym/envs/llvm/service/Benchmark.h @@ -38,13 +38,11 @@ class Benchmark { public: Benchmark(const std::string& name, const Bitcode& bitcode, const boost::filesystem::path& workingDirectory, - std::optional bitcodePath = std::nullopt, const BaselineCosts* baselineCosts = nullptr); Benchmark(const std::string& name, std::unique_ptr context, std::unique_ptr module, size_t bitcodeSize, const boost::filesystem::path& workingDirectory, - std::optional bitcodePath = std::nullopt, const BaselineCosts* baselineCosts = nullptr); // Make a copy of the benchmark. 
@@ -52,8 +50,6 @@ class Benchmark { inline const std::string& name() const { return name_; } - inline const std::optional bitcodePath() const { return bitcodePath_; } - inline const size_t bitcodeSize() const { return bitcodeSize_; } inline llvm::Module& module() { return *module_; } @@ -90,9 +86,6 @@ class Benchmark { const std::string name_; // The length of the bitcode string for this benchmark. const size_t bitcodeSize_; - // The path of the bitcode file for this benchmark. This is optional - - // benchmarks do not have to be backed by a file. - const std::optional bitcodePath_; }; } // namespace compiler_gym::llvm_service diff --git a/compiler_gym/envs/llvm/service/BenchmarkFactory.cc b/compiler_gym/envs/llvm/service/BenchmarkFactory.cc index 98d9ed4dbe..45b6f8a7e9 100644 --- a/compiler_gym/envs/llvm/service/BenchmarkFactory.cc +++ b/compiler_gym/envs/llvm/service/BenchmarkFactory.cc @@ -35,18 +35,22 @@ BenchmarkFactory::BenchmarkFactory(const boost::filesystem::path& workingDirecto rand_(rand.has_value() ? *rand : std::mt19937_64(std::random_device()())), loadedBenchmarksSize_(0), maxLoadedBenchmarkSize_(maxLoadedBenchmarkSize) { - // Register all benchmarks from the site data directory. - if (fs::is_directory(kSiteBenchmarksDir)) { - CRASH_IF_ERROR(scanSiteDataDirectory()); - } else { - LOG(INFO) << "LLVM site benchmark directory not found: " << kSiteBenchmarksDir.string(); + VLOG(2) << "BenchmarkFactory initialized"; +} + +Status BenchmarkFactory::getBenchmark(const std::string& uri, + std::unique_ptr* benchmark) { + // Check if the benchmark has already been loaded into memory. + auto loaded = benchmarks_.find(uri); + if (loaded != benchmarks_.end()) { + *benchmark = loaded->second.clone(workingDirectory_); + return Status::OK; } - VLOG(2) << "BenchmarkFactory initialized with " << numBenchmarks() << " benchmarks"; + return Status(StatusCode::NOT_FOUND, "Benchmark not found"); } -Status BenchmarkFactory::addBitcode(const std::string& uri, const Bitcode& bitcode, - std::optional bitcodePath) { +Status BenchmarkFactory::addBitcode(const std::string& uri, const Bitcode& bitcode) { Status status; std::unique_ptr context = std::make_unique(); std::unique_ptr module = makeModule(*context, bitcode, uri, &status); @@ -59,34 +63,17 @@ Status BenchmarkFactory::addBitcode(const std::string& uri, const Bitcode& bitco << " exceeds maximum in-memory cache capacity " << maxLoadedBenchmarkSize_ << ", " << benchmarks_.size() << " bitcodes"; int evicted = 0; - // Evict benchmarks until we have reduced capacity below 50%. Use a - // bounded for loop to prevent infinite loop if we get "unlucky" and - // have no valid candidates to unload. + // Evict benchmarks until we have reduced capacity below 50%. const size_t targetCapacity = maxLoadedBenchmarkSize_ / 2; - for (size_t i = 0; i < benchmarks_.size() * 2; ++i) { - // We have run out of benchmarks to evict, or have freed up - // enough capacity. - if (!benchmarks_.size() || loadedBenchmarksSize_ < targetCapacity) { - break; - } - + while (benchmarks_.size() && loadedBenchmarksSize_ > targetCapacity) { // Select a cached benchmark randomly. std::uniform_int_distribution distribution(0, benchmarks_.size() - 1); size_t index = distribution(rand_); auto iterator = std::next(std::begin(benchmarks_), index); - // Check that the benchmark has an on-disk bitcode file which - // can be loaded to re-cache this bitcode. If not, we cannot - // evict it. 
- if (!iterator->second.bitcodePath().has_value()) { - continue; - } - - // Evict the benchmark: add it to the pool of unloaded benchmarks and - // delete it from the pool of loaded benchmarks. + // Evict the benchmark from the pool of loaded benchmarks. ++evicted; loadedBenchmarksSize_ -= iterator->second.bitcodeSize(); - unloadedBitcodePaths_.insert({iterator->first, *iterator->second.bitcodePath()}); benchmarks_.erase(iterator); } @@ -94,239 +81,19 @@ Status BenchmarkFactory::addBitcode(const std::string& uri, const Bitcode& bitco << loadedBenchmarksSize_ << ", " << benchmarks_.size() << " bitcodes"; } - benchmarks_.insert({uri, Benchmark(uri, std::move(context), std::move(module), bitcodeSize, - workingDirectory_, bitcodePath)}); + benchmarks_.insert( + {uri, Benchmark(uri, std::move(context), std::move(module), bitcodeSize, workingDirectory_)}); loadedBenchmarksSize_ += bitcodeSize; return Status::OK; } -Status BenchmarkFactory::addBitcodeFile(const std::string& uri, - const boost::filesystem::path& path) { - if (!fs::exists(path)) { - return Status(StatusCode::NOT_FOUND, fmt::format("File not found: \"{}\"", path.string())); - } - unloadedBitcodePaths_[uri] = path; - return Status::OK; -} - -Status BenchmarkFactory::addBitcodeUriAlias(const std::string& src, const std::string& dst) { - // TODO(github.com/facebookresearch/CompilerGym/issues/2): Add support - // for additional protocols, e.g. http://. - if (dst.rfind("file:////", 0) != 0) { - return Status(StatusCode::INVALID_ARGUMENT, - fmt::format("Unsupported benchmark URI protocol: \"{}\"", dst)); - } - - // Resolve path from file:/// protocol URI. - const boost::filesystem::path path{dst.substr(util::strLen("file:///"))}; - return addBitcodeFile(src, path); -} - -namespace { - -bool endsWith(const std::string& str, const std::string& suffix) { - return str.size() >= suffix.size() && - str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0; -} - -} // anonymous namespace - -Status BenchmarkFactory::addDirectoryOfBitcodes(const boost::filesystem::path& root) { - VLOG(3) << "addDirectoryOfBitcodes(" << root.string() << ")"; - if (!fs::is_directory(root)) { - return Status(StatusCode::INVALID_ARGUMENT, - fmt::format("Directory not found: \"{}\"", root.string())); - } - - // Check if there is a manifest file that we can read, rather than having to - // enumerate the directory ourselves. - const auto manifestPath = fs::path(root.string() + ".MANIFEST"); - if (fs::is_regular_file(manifestPath)) { - VLOG(3) << "Reading manifest file: " << manifestPath; - return addDirectoryOfBitcodes(root, manifestPath); - } - - const auto rootPathSize = root.string().size(); - for (auto it : fs::recursive_directory_iterator(root, fs::symlink_option::recurse)) { - if (!fs::is_regular_file(it)) { - continue; - } - const std::string& path = it.path().string(); - - if (!endsWith(path, kExpectedExtension)) { - continue; - } - - // The name of the benchmark is path, relative to the root, without the - // file extension. 
- const std::string name = - path.substr(rootPathSize + 1, path.size() - rootPathSize - kExpectedExtension.size() - 1); - const std::string uri = fmt::format("benchmark://{}", name); - - RETURN_IF_ERROR(addBitcodeFile(uri, path)); - } - - return Status::OK; -} - -Status BenchmarkFactory::addDirectoryOfBitcodes(const boost::filesystem::path& root, - const boost::filesystem::path& manifestPath) { - std::ifstream infile(manifestPath.string()); - std::string relPath; - while (std::getline(infile, relPath)) { - if (!endsWith(relPath, kExpectedExtension)) { - continue; - } - - const fs::path path = root / relPath; - const std::string name = relPath.substr(0, relPath.size() - kExpectedExtension.size()); - const std::string uri = fmt::format("benchmark://{}", name); - - RETURN_IF_ERROR(addBitcodeFile(uri, path)); - } - - return Status::OK; -} - -Status BenchmarkFactory::getBenchmark(std::unique_ptr* benchmark) { - if (!benchmarks_.size() && !unloadedBitcodePaths_.size()) { - return Status(StatusCode::NOT_FOUND, - fmt::format("No benchmarks registered. Site data directory: `{}`", - kSiteBenchmarksDir.string())); - } - - const size_t unloadedBenchmarkCount = unloadedBitcodePaths_.size(); - const size_t loadedBenchmarkCount = benchmarks_.size(); - - const size_t benchmarkCount = unloadedBenchmarkCount + loadedBenchmarkCount; - - std::uniform_int_distribution distribution(0, benchmarkCount - 1); - size_t index = distribution(rand_); - - if (index < unloadedBenchmarkCount) { - // Select a random unloaded benchmark to load and move to the loaded - // benchmark collection. - auto unloadedBenchmark = std::next(std::begin(unloadedBitcodePaths_), index); - CHECK(unloadedBenchmark != unloadedBitcodePaths_.end()); - RETURN_IF_ERROR(loadBenchmark(unloadedBenchmark, benchmark)); - } else { - auto loadedBenchmark = std::next(std::begin(benchmarks_), index - unloadedBenchmarkCount); - CHECK(loadedBenchmark != benchmarks_.end()); - *benchmark = loadedBenchmark->second.clone(workingDirectory_); - } - - return Status::OK; -} - -Status BenchmarkFactory::getBenchmark(const std::string& uri, - std::unique_ptr* benchmark) { - std::string resolvedUri = uri; - // Prepend benchmark:// protocol if not specified. E.g. "foo/bar" resolves to - // "benchmark://foo/bar", but "file:///foo/bar" is not changed. - if (uri.find("://") == std::string::npos) { - resolvedUri = fmt::format("benchmark://{}", uri); - } - - auto loaded = benchmarks_.find(resolvedUri); - if (loaded != benchmarks_.end()) { - *benchmark = loaded->second.clone(workingDirectory_); - return Status::OK; - } - - auto unloaded = unloadedBitcodePaths_.find(resolvedUri); - if (unloaded != unloadedBitcodePaths_.end()) { - RETURN_IF_ERROR(loadBenchmark(unloaded, benchmark)); - return Status::OK; - } - - // No exact name match - attempt to match the URI prefix. - return getBenchmarkByUriPrefix(uri, resolvedUri, benchmark); -} - -Status BenchmarkFactory::getBenchmarkByUriPrefix(const std::string& uriPrefix, - const std::string& resolvedUriPrefix, - std::unique_ptr* benchmark) { - // Make a list of all of the known benchmarks which match this prefix. 
- std::vector candidateBenchmarks; - for (const auto& it : unloadedBitcodePaths_) { - const std::string& uri = it.first; - if (uri.rfind(resolvedUriPrefix, 0) == 0) { - candidateBenchmarks.push_back(uri.c_str()); - } - } - for (const auto& it : benchmarks_) { - const std::string& uri = it.first; - if (uri.rfind(resolvedUriPrefix, 0) == 0) { - candidateBenchmarks.push_back(uri.c_str()); - } - } - - const size_t candidatesCount = candidateBenchmarks.size(); - if (!candidatesCount) { - return Status(StatusCode::INVALID_ARGUMENT, fmt::format("Unknown benchmark \"{}\"", uriPrefix)); - } - - // Select randomly from the list of candidates. - std::uniform_int_distribution distribution(0, candidatesCount - 1); - size_t index = distribution(rand_); - return getBenchmark(candidateBenchmarks[index], benchmark); -} - -std::vector BenchmarkFactory::getBenchmarkNames() const { - std::vector names; - names.reserve(unloadedBitcodePaths_.size() + benchmarks_.size()); - for (const auto& it : unloadedBitcodePaths_) { - names.push_back(it.first); - } - for (const auto& it : benchmarks_) { - names.push_back(it.first); - } - return names; -} - -Status BenchmarkFactory::loadBenchmark( - std::unordered_map::const_iterator iterator, - std::unique_ptr* benchmark) { - const std::string uri = iterator->first; - const char* path = iterator->second.string().c_str(); - - VLOG(2) << "loadBenchmark(" << path << ")"; +Status BenchmarkFactory::addBitcode(const std::string& uri, const fs::path& path) { + VLOG(2) << "addBitcode(" << path.string() << ")"; Bitcode bitcode; - std::ifstream ifs; - ifs.open(path); - - ifs.seekg(0, std::ios::end); - if (ifs.fail()) { - return Status(StatusCode::NOT_FOUND, fmt::format("Error reading file: \"{}\"", path)); - } - - std::streampos fileSize = ifs.tellg(); - if (!fileSize) { - return Status(StatusCode::INVALID_ARGUMENT, fmt::format("File is empty: \"{}\"", path)); - } - - bitcode.resize(fileSize); - ifs.seekg(0); - ifs.read(&bitcode[0], bitcode.size()); - if (ifs.fail()) { - return Status(StatusCode::NOT_FOUND, fmt::format("Error reading file: \"{}\"", path)); - } - - RETURN_IF_ERROR(addBitcode(uri, bitcode, fs::path(path))); - - unloadedBitcodePaths_.erase(iterator); - *benchmark = benchmarks_.find(uri)->second.clone(workingDirectory_); - return Status::OK; -} - -size_t BenchmarkFactory::numBenchmarks() const { - return benchmarks_.size() + unloadedBitcodePaths_.size(); -} - -Status BenchmarkFactory::scanSiteDataDirectory() { - return addDirectoryOfBitcodes(kSiteBenchmarksDir); + RETURN_IF_ERROR(readBitcodeFile(path, &bitcode)); + return addBitcode(uri, bitcode); } } // namespace compiler_gym::llvm_service diff --git a/compiler_gym/envs/llvm/service/BenchmarkFactory.h b/compiler_gym/envs/llvm/service/BenchmarkFactory.h index 11ecf7148d..8aef66dbdf 100644 --- a/compiler_gym/envs/llvm/service/BenchmarkFactory.h +++ b/compiler_gym/envs/llvm/service/BenchmarkFactory.h @@ -31,10 +31,8 @@ constexpr size_t kMaxLoadedBenchmarkSize = 512 * 1024 * 1024; // sessions. Example usage: // // BenchmarkFactory factory; -// for (int i = 0; i < 10; ++i) { -// auto benchmark = factory.getBenchmark(); -// // ... do fun stuff -// } +// auto benchmark = factory.getBenchmark("file:////tmp/my_bitcode.bc"); +// // ... do fun stuff class BenchmarkFactory { public: // Construct a benchmark factory. rand is a random seed used to control the @@ -46,69 +44,16 @@ class BenchmarkFactory { std::optional rand = std::nullopt, size_t maxLoadedBenchmarkSize = kMaxLoadedBenchmarkSize); - // Add a new bitcode. 
bitcodePath is optional. If provided, it allows the - // newly added benchmark to be evicted from the in-memory cache. - [[nodiscard]] grpc::Status addBitcode( - const std::string& uri, const Bitcode& bitcode, - std::optional bitcodePath = std::nullopt); - - // Add a bitcode URI alias. For example, - // addBitcodeFile("benchmark://foo", "file:///tmp/foo.bc") - // adds a new benchmark "benchmark://foo" which resolves to the path - // "/tmp/foo.bc". - [[nodiscard]] grpc::Status addBitcodeUriAlias(const std::string& src, const std::string& dst); - - // Add a directory of bitcode files. The format for added benchmark URIs is - // `benchmark://`, where relStem is the path of the file, relative - // to the root of the directory, without the file extension. - // - // Note that if any of the bitcodes are invalid, this error will be latent - // until a call to getBenchmark() attempts to load it. - [[nodiscard]] grpc::Status addDirectoryOfBitcodes(const boost::filesystem::path& path); - - // Get a random benchmark. - [[nodiscard]] grpc::Status getBenchmark(std::unique_ptr* benchmark); - // Get the requested named benchmark. [[nodiscard]] grpc::Status getBenchmark(const std::string& uri, std::unique_ptr* benchmark); - // Enumerate the list of available benchmark names that can be - // passed to getBenchmark(). - [[nodiscard]] std::vector getBenchmarkNames() const; - - // Scan the site data directory for new files. This is used to indicate that - // the directory has changed. - [[nodiscard]] grpc::Status scanSiteDataDirectory(); + [[nodiscard]] grpc::Status addBitcode(const std::string& uri, const Bitcode& bitcode); - size_t numBenchmarks() const; - - // Register the path of a new bitcode file using the given URI. If the URI - // already exists, it is replaced. - [[nodiscard]] grpc::Status addBitcodeFile(const std::string& uri, - const boost::filesystem::path& path); + [[nodiscard]] grpc::Status addBitcode(const std::string& uri, + const boost::filesystem::path& path); private: - // Add a directory of bitcode files by reading a MANIFEST file. The manifest - // file must consist of a single relative path per line. - [[nodiscard]] grpc::Status addDirectoryOfBitcodes(const boost::filesystem::path& path, - const boost::filesystem::path& manifestPath); - - // Fetch a random benchmark matching a given URI prefix. - [[nodiscard]] grpc::Status getBenchmarkByUriPrefix(const std::string& uriPrefix, - const std::string& resolvedUriPrefix, - std::unique_ptr* benchmark); - - [[nodiscard]] grpc::Status loadBenchmark( - std::unordered_map::const_iterator iterator, - std::unique_ptr* benchmark); - - // A map from benchmark name to the path of a bitcode file. This is used to - // store the paths of benchmarks w - // hich have not yet been loaded into memory. - // Once loaded, they are removed from this map and replaced by an entry in - // benchmarks_. - std::unordered_map unloadedBitcodePaths_; // A mapping from URI to benchmarks which have been loaded into memory. 
std::unordered_map benchmarks_; diff --git a/compiler_gym/envs/llvm/service/LlvmService.cc b/compiler_gym/envs/llvm/service/LlvmService.cc index 068ba5ef35..4ad7d79973 100644 --- a/compiler_gym/envs/llvm/service/LlvmService.cc +++ b/compiler_gym/envs/llvm/service/LlvmService.cc @@ -14,6 +14,7 @@ #include "compiler_gym/service/proto/compiler_gym_service.pb.h" #include "compiler_gym/util/EnumUtil.h" #include "compiler_gym/util/GrpcStatusMacros.h" +#include "compiler_gym/util/StrLenConstexpr.h" #include "compiler_gym/util/Version.h" #include "llvm/ADT/Triple.h" #include "llvm/Config/llvm-config.h" @@ -53,13 +54,13 @@ Status LlvmService::StartSession(ServerContext* /* unused */, const StartSession StartSessionReply* reply) { const std::lock_guard lock(sessionsMutex_); - std::unique_ptr benchmark; - if (request->benchmark().size()) { - RETURN_IF_ERROR(benchmarkFactory_.getBenchmark(request->benchmark(), &benchmark)); - } else { - RETURN_IF_ERROR(benchmarkFactory_.getBenchmark(&benchmark)); + if (!request->benchmark().size()) { + return Status(StatusCode::INVALID_ARGUMENT, "No benchmark URI set for StartSession()"); } + std::unique_ptr benchmark; + RETURN_IF_ERROR(benchmarkFactory_.getBenchmark(request->benchmark(), &benchmark)); + reply->set_benchmark(benchmark->name()); VLOG(1) << "StartSession(" << benchmark->name() << "), [" << nextSessionId_ << "]"; @@ -138,19 +139,23 @@ Status LlvmService::addBenchmark(const ::compiler_gym::Benchmark& request) { return Status(StatusCode::INVALID_ARGUMENT, "Benchmark must have a URI"); } - if (uri == "service://scan-site-data") { - return benchmarkFactory_.scanSiteDataDirectory(); - } - const auto& programFile = request.program(); switch (programFile.data_case()) { case ::compiler_gym::File::DataCase::kContents: - RETURN_IF_ERROR(benchmarkFactory_.addBitcode( - uri, llvm::SmallString<0>(programFile.contents().begin(), programFile.contents().end()))); - break; - case ::compiler_gym::File::DataCase::kUri: - RETURN_IF_ERROR(benchmarkFactory_.addBitcodeUriAlias(uri, programFile.uri())); - break; + return benchmarkFactory_.addBitcode( + uri, llvm::SmallString<0>(programFile.contents().begin(), programFile.contents().end())); + case ::compiler_gym::File::DataCase::kUri: { + // Check that protocol of the benmchmark URI. + if (programFile.uri().find("file:///") != 0) { + return Status(StatusCode::INVALID_ARGUMENT, + fmt::format("Invalid benchmark data URI. 
" + "Only the file:/// protocol is supported: \"{}\"", + programFile.uri())); + } + + const fs::path path(programFile.uri().substr(util::strLen("file:///"), std::string::npos)); + return benchmarkFactory_.addBitcode(uri, path); + } case ::compiler_gym::File::DataCase::DATA_NOT_SET: return Status(StatusCode::INVALID_ARGUMENT, "No program set"); } @@ -158,16 +163,6 @@ Status LlvmService::addBenchmark(const ::compiler_gym::Benchmark& request) { return Status::OK; } -Status LlvmService::GetBenchmarks(ServerContext* /* unused */, - const GetBenchmarksRequest* /* unused */, - GetBenchmarksReply* reply) { - for (const auto& benchmark : benchmarkFactory_.getBenchmarkNames()) { - reply->add_benchmark(benchmark); - } - - return Status::OK; -} - Status LlvmService::session(uint64_t id, LlvmSession** environment) { auto it = sessions_.find(id); if (it == sessions_.end()) { diff --git a/compiler_gym/envs/llvm/service/LlvmService.h b/compiler_gym/envs/llvm/service/LlvmService.h index 3f85a438b9..4e321e1f91 100644 --- a/compiler_gym/envs/llvm/service/LlvmService.h +++ b/compiler_gym/envs/llvm/service/LlvmService.h @@ -48,9 +48,6 @@ class LlvmService final : public CompilerGymService::Service { grpc::Status AddBenchmark(grpc::ServerContext* context, const AddBenchmarkRequest* request, AddBenchmarkReply* reply) final override; - grpc::Status GetBenchmarks(grpc::ServerContext* context, const GetBenchmarksRequest* request, - GetBenchmarksReply* reply) final override; - protected: grpc::Status session(uint64_t id, LlvmSession** environment); grpc::Status session(uint64_t id, const LlvmSession** environment) const; diff --git a/compiler_gym/service/proto/compiler_gym_service.proto b/compiler_gym/service/proto/compiler_gym_service.proto index ef126e0760..a60fb97b6e 100644 --- a/compiler_gym/service/proto/compiler_gym_service.proto +++ b/compiler_gym/service/proto/compiler_gym_service.proto @@ -23,7 +23,8 @@ service CompilerGymService { rpc GetSpaces(GetSpacesRequest) returns (GetSpacesReply); // Start a new CompilerGym service session. This allocates a new session on // the service and returns a session ID. To terminate the session, call - // EndSession() once done. + // EndSession() once done. Raises grpc::StatusCode::NOT_FOUND if the requested + // benchmark URI is not found. rpc StartSession(StartSessionRequest) returns (StartSessionReply); // Fork a session. This creates a new session in exactly the same state. The // new session must be terminated with EndSession() once done. This returns diff --git a/compiler_gym/third_party/cBench/make_llvm_module.py b/compiler_gym/third_party/cBench/make_llvm_module.py index fccdda3b63..3864219c75 100644 --- a/compiler_gym/third_party/cBench/make_llvm_module.py +++ b/compiler_gym/third_party/cBench/make_llvm_module.py @@ -32,7 +32,7 @@ def make_cbench_llvm_module( benchmark = make_benchmark(inputs=src_files, copt=cflags or []) # Write just the bitcode to file. 
with open(output_path, "wb") as f: - f.write(benchmark.program.contents) + f.write(benchmark.proto.program.contents) def main(): diff --git a/compiler_gym/third_party/llvm/BUILD b/compiler_gym/third_party/llvm/BUILD index d45b744c4a..d1458c2979 100644 --- a/compiler_gym/third_party/llvm/BUILD +++ b/compiler_gym/third_party/llvm/BUILD @@ -25,6 +25,12 @@ filegroup( visibility = ["//visibility:public"], ) +filegroup( + name = "llvm-as", + srcs = ["bin/llvm-as"], + visibility = ["//visibility:public"], +) + filegroup( name = "llvm-diff", srcs = ["bin/llvm-diff"], @@ -83,6 +89,20 @@ genrule( cmd = "mkdir -p $$(dirname $@) && cp $< $@", ) +genrule( + name = "make_llvm-as", + srcs = select({ + "@llvm//:darwin": [ + "@clang-llvm-10.0.0-x86_64-apple-darwin//:llvm-as", + ], + "//conditions:default": [ + "@clang-llvm-10.0.0-x86_64-linux-gnu-ubuntu-18.04//:llvm-as", + ], + }), + outs = ["bin/llvm-as"], + cmd = "mkdir -p $$(dirname $@) && cp $< $@", +) + genrule( name = "make_llvm-diff", srcs = select({ diff --git a/docs/source/compiler_gym/datasets.rst b/docs/source/compiler_gym/datasets.rst index 237e5139e1..77dc1be2b1 100644 --- a/docs/source/compiler_gym/datasets.rst +++ b/docs/source/compiler_gym/datasets.rst @@ -1,35 +1,54 @@ compiler_gym.datasets ===================== -An instance of a CompilerGym environment uses a benchmark as the program being -optimized. Collections of benchmarks are packaged into datasets, storing -additional metadata such as the license, defined by the -:class:`Dataset ` class. +An instance of a CompilerGym environment uses a :class:`Benchmark +` as the program being optimized. Collections +of benchmarks are packaged into a :class:`Dataset +`, storing additional metadata such as the +license. -A simple filesystem-based scheme is used to manage datasets: +.. contents:: + :local: -* Every top-level directory in an environment's site-data folder is - treated as a "dataset". +.. currentmodule:: compiler_gym.datasets -* A benchmarks.inactive directory contains datasets that the user has - downloaded, but are not used by the environment. Moving a directory - from /benchmarks to /benchmarks.inactive means that the - environment will no longer use it. -* Datasets can be packaged as .tar.bz2 archives and downloaded from - the web or local filesystem. Environments may advertise a list of - available datasets. +Benchmark +--------- -Datasets are packaged for each compiler and stored locally in the filesystem. -The filesystem location can be queries using -:attr:`CompilerEnv.datasets_site_path `: +.. autoclass:: Benchmark + :members: - >>> env = gym.make("llvm-v0") - >>> env.datasets_site_path - /home/user/.local/share/compiler_gym/llvm/10.0.0/bitcode_benchmarks -The :mod:`compiler_gym.bin.datasets` module can be used to download and manage -datasets for an environment. +Datasets +-------- -.. automodule:: compiler_gym.datasets - :members: + .. autoclass:: Datasets + :members: + + +Dataset +------- + +.. autoclass:: Dataset + :members: + + .. automethod:: __init__ + + +TarDataset +---------- + +.. autoclass:: Dataset + :members: + + .. automethod:: __init__ + + +TarDatasetWithManifest +---------------------- + +.. autoclass:: Dataset + :members: + + .. 
automethod:: __init__ diff --git a/examples/example_compiler_gym_service/README.md b/examples/example_compiler_gym_service/README.md index 640d6f8899..cade0209ea 100644 --- a/examples/example_compiler_gym_service/README.md +++ b/examples/example_compiler_gym_service/README.md @@ -17,7 +17,8 @@ Features: * Enforces the service contract, e.g. `StartSession()` must be called before `EndSession()`, list indices must be in-bounds, etc. * Implements all of the RPC endpoints. -* It has two programs "foo" and "bar". +* It has a single dataset "benchmark://example-v0" with two programs "foo" and + "bar". * It has a static action space with three items: `["a", "b", "c"]`. The action space never changes. Actions never end the episode. * There are two observation spaces: diff --git a/examples/example_compiler_gym_service/__init__.py b/examples/example_compiler_gym_service/__init__.py index 991bf3bd81..d918b1dea0 100644 --- a/examples/example_compiler_gym_service/__init__.py +++ b/examples/example_compiler_gym_service/__init__.py @@ -3,9 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """This module demonstrates how to """ +from typing import Optional + +from compiler_gym.datasets import Benchmark, Dataset from compiler_gym.spaces import Reward from compiler_gym.util.registration import register -from compiler_gym.util.runfiles_path import runfiles_path +from compiler_gym.util.runfiles_path import runfiles_path, site_data_path class RuntimeReward(Reward): @@ -40,6 +43,35 @@ def update(self, action, observations, observation_view): return reward +class ExampleDataset(Dataset): + def __init__(self, *args, **kwargs): + super().__init__( + name="benchmark://example-v0", + license="MIT", + description="An example dataset", + site_data_base=site_data_path("example_dataset"), + ) + self._benchmarks = { + "benchmark://example-v0/foo": Benchmark.from_file_contents( + "benchmark://example-v0/foo", "Ir data".encode("utf-8") + ), + "benchmark://example-v0/bar": Benchmark.from_file_contents( + "benchmark://example-v0/bar", "Ir data".encode("utf-8") + ), + } + + def benchmark_uris(self): + yield from self._benchmarks.keys() + + def benchmark(self, uri: Optional[str] = None): + if uri is None: + return self.random.choice(list(self._benchmarks.values())) + elif uri in self._benchmarks: + return self._benchmarks[uri] + else: + raise LookupError("Unknown program name") + + # Register the example service on module import. After importing this module, # the example-v0 environment will be available to gym.make(...). 
register( @@ -50,6 +82,7 @@ def update(self, action, observations, observation_view): "examples/example_compiler_gym_service/service_cc/compiler_gym-example-service-cc" ), "rewards": [RuntimeReward()], + "datasets": [ExampleDataset()], }, ) @@ -61,5 +94,6 @@ def update(self, action, observations, observation_view): "examples/example_compiler_gym_service/service_py/compiler_gym-example-service-py" ), "rewards": [RuntimeReward()], + "datasets": [ExampleDataset()], }, ) diff --git a/examples/example_compiler_gym_service/env_tests.py b/examples/example_compiler_gym_service/env_tests.py index 24ece5720b..f174f7b368 100644 --- a/examples/example_compiler_gym_service/env_tests.py +++ b/examples/example_compiler_gym_service/env_tests.py @@ -85,8 +85,8 @@ def test_reward_before_reset(env: CompilerEnv): def test_reset_invalid_benchmark(env: CompilerEnv): """Test requesting a specific benchmark.""" - with pytest.raises(ValueError) as ctx: - env.reset(benchmark="foobar") + with pytest.raises(LookupError) as ctx: + env.reset(benchmark="example-v0/foobar") assert str(ctx.value) == "Unknown program name" @@ -166,7 +166,10 @@ def test_rewards(env: CompilerEnv): def test_benchmarks(env: CompilerEnv): - assert env.benchmarks == ["foo", "bar"] + assert list(env.benchmarks) == [ + "benchmark://example-v0/foo", + "benchmark://example-v0/bar", + ] if __name__ == "__main__": diff --git a/examples/example_compiler_gym_service/service_cc/ExampleService.cc b/examples/example_compiler_gym_service/service_cc/ExampleService.cc index 6ba724e040..b2dcfcfdd2 100644 --- a/examples/example_compiler_gym_service/service_cc/ExampleService.cc +++ b/examples/example_compiler_gym_service/service_cc/ExampleService.cc @@ -27,7 +27,9 @@ template return Status::OK; } -std::vector getBenchmarks() { return {"foo", "bar"}; } +std::vector getBenchmarkUris() { + return {"benchmark://example-v0/foo", "benchmark://example-v0/bar"}; +} std::vector getActionSpaces() { ActionSpace space; @@ -95,7 +97,7 @@ Status ExampleService::StartSession(ServerContext* /* unused*/, const StartSessi // Determine the benchmark to use. std::string benchmark = request->benchmark(); - const auto benchmarks = getBenchmarks(); + const auto benchmarks = getBenchmarkUris(); if (!benchmark.empty() && std::find(benchmarks.begin(), benchmarks.end(), benchmark) == benchmarks.end()) { return Status(StatusCode::INVALID_ARGUMENT, "Unknown program name"); @@ -141,7 +143,7 @@ Status ExampleService::Step(ServerContext* /* unused*/, const StepRequest* reque Status ExampleService::GetBenchmarks(grpc::ServerContext* /*unused*/, const GetBenchmarksRequest* /*unused*/, GetBenchmarksReply* reply) { - const auto benchmarks = getBenchmarks(); + const auto benchmarks = getBenchmarkUris(); *reply->mutable_benchmark() = {benchmarks.begin(), benchmarks.end()}; return Status::OK; } diff --git a/examples/example_compiler_gym_service/service_py/example_service.py b/examples/example_compiler_gym_service/service_py/example_service.py index 24904165ca..f071e150f1 100755 --- a/examples/example_compiler_gym_service/service_py/example_service.py +++ b/examples/example_compiler_gym_service/service_py/example_service.py @@ -33,7 +33,7 @@ logging.basicConfig(level=logging.DEBUG) # The names of the benchmarks that are supported -BENCHMARKS = ["foo", "bar"] +BENCHMARKS = ["benchmark://example-v0/foo", "benchmark://example-v0/bar"] # The list of actions that are supported by this service. 
This example uses a # static (unchanging) action space, but this could be extended to support a diff --git a/leaderboard/llvm_codesize/eval_policy.py b/leaderboard/llvm_codesize/eval_policy.py index 0da696f73c..4417648e3f 100644 --- a/leaderboard/llvm_codesize/eval_policy.py +++ b/leaderboard/llvm_codesize/eval_policy.py @@ -187,7 +187,8 @@ def main(argv): # Install the required dataset and build the list of benchmarks to # evaluate. env.require_dataset(FLAGS.test_dataset) - benchmarks = sorted([b for b in env.benchmarks if FLAGS.test_dataset in b]) + test_dataset = env.datasets[FLAGS.test_dataset] + benchmarks = sorted(test_dataset.benchmark_uris()) if FLAGS.max_benchmarks: benchmarks = benchmarks[: FLAGS.max_benchmarks] diff --git a/setup.py b/setup.py index 9d816c6b0d..a7ae08374f 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,7 @@ def get_tag(self): "compiler_gym.datasets", "compiler_gym.envs", "compiler_gym.envs.llvm", + "compiler_gym.envs.llvm.datasets", "compiler_gym.envs.llvm.service", "compiler_gym.envs.llvm.service.passes", "compiler_gym.service", diff --git a/tests/compiler_env_test.py b/tests/compiler_env_test.py index 052bc7b87d..761569825a 100644 --- a/tests/compiler_env_test.py +++ b/tests/compiler_env_test.py @@ -62,7 +62,7 @@ def test_benchmark_constructor_arg(env: CompilerEnv): env = gym.make("llvm-v0", benchmark="cBench-v1/dijkstra") try: - assert env.benchmark == "cBench-v1/dijkstra" + assert env.benchmark == "benchmark://cBench-v1/dijkstra" finally: env.close() diff --git a/tests/datasets/BUILD b/tests/datasets/BUILD new file mode 100644 index 0000000000..3cd6ab1e10 --- /dev/null +++ b/tests/datasets/BUILD @@ -0,0 +1,26 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +load("@rules_python//python:defs.bzl", "py_test") + +py_test( + name = "benchmark_test", + timeout = "short", + srcs = ["benchmark_test.py"], + deps = [ + "//compiler_gym/datasets", + "//tests:test_main", + "//tests/pytest_plugins:common", + ], +) + +py_test( + name = "dataset_test", + srcs = ["dataset_test.py"], + deps = [ + "//compiler_gym/datasets", + "//tests:test_main", + "//tests/pytest_plugins:common", + ], +) diff --git a/tests/datasets/benchmark_test.py b/tests/datasets/benchmark_test.py new file mode 100644 index 0000000000..afed163384 --- /dev/null +++ b/tests/datasets/benchmark_test.py @@ -0,0 +1,50 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
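The eval_policy.py change above selects the test set through env.datasets rather than substring-filtering env.benchmarks. A hedged sketch of that lookup pattern outside the leaderboard script; the dataset key and the cap of ten benchmarks are illustrative choices:

    import gym

    import compiler_gym  # noqa: F401  (registers the llvm-v0 environment)

    env = gym.make("llvm-v0")
    try:
        env.require_dataset("cBench-v1")     # ensure the dataset is installed
        dataset = env.datasets["cBench-v1"]  # mirrors env.datasets[FLAGS.test_dataset] above
        for uri in sorted(dataset.benchmark_uris())[:10]:
            env.reset(benchmark=uri)
            # ... run the policy on this benchmark
    finally:
        env.close()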
+"""Unit tests for //compiler_gym/datasets:benchmark.""" +import pytest + +from compiler_gym.datasets.dataset import Benchmark +from compiler_gym.service.proto import Benchmark as BenchmarkProto +from tests.test_main import main + +pytest_plugins = ["tests.pytest_plugins.common"] + + +def test_benchmark_attribute_outside_init(): + """Test that new attributes cannot be added to Benchmark.""" + benchmark = Benchmark(None) + with pytest.raises(AttributeError): + benchmark.foobar = 123 # noqa + + +def test_benchmark_subclass_attribute_outside_init(): + """Test that new attributes can be added to Benchmark subclass.""" + + class TestBenchmark(Benchmark): + pass + + benchmark = TestBenchmark(None) + benchmark.foobar = 123 + assert benchmark.foobar == 123 + + +def test_benchmark_properties(): + """Test benchmark properties.""" + benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar")) + assert benchmark.uri == "benchmark://example-v0/foobar" + assert benchmark.proto == BenchmarkProto(uri="benchmark://example-v0/foobar") + + +def test_benchmark_immutable(): + """Test that benchmark properties are immutable.""" + benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar")) + with pytest.raises(AttributeError): + benchmark.uri = 123 + with pytest.raises(AttributeError): + benchmark.proto = 123 + + +if __name__ == "__main__": + main() diff --git a/tests/datasets/dataset_test.py b/tests/datasets/dataset_test.py new file mode 100644 index 0000000000..eaded8665e --- /dev/null +++ b/tests/datasets/dataset_test.py @@ -0,0 +1,134 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Unit tests for //compiler_gym/datasets.""" +import re +from pathlib import Path + +import pytest + +from compiler_gym.datasets.dataset import BENCHMARK_URI_RE, DATASET_NAME_RE, Dataset +from tests.test_main import main + +pytest_plugins = ["tests.pytest_plugins.common"] + + +def _rgx_match(regex, groupname, string): + match = re.match(regex, string) + assert match, f"Failed to match regex '{regex}' using string '{groupname}'" + return match.group(groupname) + + +@pytest.mark.parametrize("regex", (DATASET_NAME_RE, BENCHMARK_URI_RE)) +def test_benchmark_uri_protocol(regex): + assert not regex.match("B?://cBench-v1/") # Invalid characters + assert not regex.match("cBench-v1/") # Missing protocol + + _rgx_match(regex, "dataset_protocol", "benchmark://cBench-v1/") == "benchmark" + _rgx_match(regex, "dataset_protocol", "Generator13://gen-v11/") == "Generator13" + + +def test_benchmark_uri_dataset(): + assert not BENCHMARK_URI_RE.match("benchmark://cBench-v1") # Missing trailing / + assert not BENCHMARK_URI_RE.match("benchmark://cBench?v0/") # Invalid character + assert not BENCHMARK_URI_RE.match("benchmark://cBench/") # Missing version suffix + + _rgx_match( + BENCHMARK_URI_RE, "dataset_name", "benchmark://cBench-v1/" + ) == "cBench-v1" + _rgx_match(BENCHMARK_URI_RE, "dataset_name", "Generator13://gen-v11/") == "gen-v11" + + +def test_benchmark_dataset_name(): + _rgx_match( + BENCHMARK_URI_RE, "dataset", "benchmark://cBench-v1/" + ) == "benchmark://cBench-v1" + _rgx_match( + BENCHMARK_URI_RE, "dataset", "Generator13://gen-v11/" + ) == "Generator13://gen-v11" + + +def test_benchmark_uri_id(): + assert not BENCHMARK_URI_RE.match("benchmark://cBench-v1/ whitespace") # Whitespace + assert not BENCHMARK_URI_RE.match("benchmark://cBench-v1/\t") # Whitespace + + 
_rgx_match(BENCHMARK_URI_RE, "benchmark_name", "benchmark://cBench-v1/") == ""
+    _rgx_match(BENCHMARK_URI_RE, "benchmark_name", "benchmark://cBench-v1/foo") == "foo"
+    _rgx_match(
+        BENCHMARK_URI_RE, "benchmark_name", "benchmark://cBench-v1/foo/123"
+    ) == "foo/123"
+    _rgx_match(
+        BENCHMARK_URI_RE,
+        "benchmark_name",
+        "benchmark://cBench-v1/foo/123?param=true&false",
+    ) == "foo/123?param=true&false"
+
+
+def test_dataset_properties(tmpwd: Path):
+    class TestDataset(Dataset):
+        def __init__(self):
+            super().__init__(
+                name="benchmark://test-v0",
+                description="A test dataset",
+                license="MIT",
+                site_data_base="test",
+            )
+
+    dataset = TestDataset()
+    assert dataset.name == "benchmark://test-v0"
+    assert dataset.protocol == "benchmark"
+    assert dataset.version == 0
+    assert dataset.description == "A test dataset"
+    assert dataset.license == "MIT"
+    assert dataset.long_description_url is None
+
+
+def test_dataset_site_data_directory(tmpwd: Path):
+    class TestDataset(Dataset):
+        def __init__(self):
+            super().__init__(
+                name="benchmark://test-v0",
+                description="A test dataset",
+                license="MIT",
+                site_data_base="test",
+            )
+
+    dataset = TestDataset()
+    assert dataset.site_data_path == tmpwd / "test" / "benchmark" / "test-v0"
+    assert not dataset.site_data_path.is_dir()  # Dir is not created until needed.
+
+
+def test_dataset_long_description_url(tmpwd: Path):
+    class TestDataset(Dataset):
+        def __init__(self):
+            super().__init__(
+                name="benchmark://test-v0",
+                description="A test dataset",
+                license="MIT",
+                long_description_url="https://facebook.com",
+                site_data_base="test",
+            )
+
+    dataset = TestDataset()
+    assert dataset.long_description_url == "https://facebook.com"
+
+
+def test_dataset_name_missing_version(tmpwd: Path):
+    class TestDataset(Dataset):
+        def __init__(self):
+            super().__init__(
+                name="benchmark://test",
+                description="A test dataset",
+                license="MIT",
+                site_data_base="test",
+            )
+
+    with pytest.raises(ValueError) as e_ctx:
+        TestDataset()
+
+    assert "Invalid dataset name: 'benchmark://test'" in str(e_ctx.value)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/llvm/BUILD b/tests/llvm/BUILD
index e5561f22b4..c8fb123d51 100644
--- a/tests/llvm/BUILD
+++ b/tests/llvm/BUILD
@@ -85,11 +85,11 @@ py_test(
 
 py_test(
     name = "datasets_test",
-    timeout = "short",
     srcs = ["datasets_test.py"],
     deps = [
         "//compiler_gym",
         "//compiler_gym/envs/llvm:legacy_datasets",
+        "//compiler_gym/envs/llvm/datasets",
         "//tests:test_main",
         "//tests/pytest_plugins:common",
         "//tests/pytest_plugins:llvm",
diff --git a/tests/llvm/autophase_test.py b/tests/llvm/autophase_test.py
index 86241c0b21..dbd853e9ba 100644
--- a/tests/llvm/autophase_test.py
+++ b/tests/llvm/autophase_test.py
@@ -12,6 +12,7 @@
 def test_autophase_crc32_feature_vector(env: CompilerEnv):
     env.benchmark = "cBench-v1/crc32"
     env.reset()
+    print(env.benchmark)
     features = env.observation["AutophaseDict"]
     print(features)  # For debugging on failure.
     assert features == {
@@ -75,4 +76,4 @@ def test_autophase_crc32_feature_vector(env: CompilerEnv):
 
 
 if __name__ == "__main__":
-    main()
+    main(debug_level=3)
diff --git a/tests/llvm/custom_benchmarks_test.py b/tests/llvm/custom_benchmarks_test.py
index a8b42e9337..738a55dd5f 100644
--- a/tests/llvm/custom_benchmarks_test.py
+++ b/tests/llvm/custom_benchmarks_test.py
@@ -11,8 +11,10 @@
 import gym
 import pytest
 
+from compiler_gym.datasets import Benchmark
 from compiler_gym.envs import LlvmEnv, llvm
-from compiler_gym.service.proto import Benchmark, File
+from compiler_gym.service.proto import Benchmark as BenchmarkProto
+from compiler_gym.service.proto import File
 from compiler_gym.util.runfiles_path import runfiles_path
 from tests.test_main import main
 
@@ -29,12 +31,12 @@ def test_reset_invalid_benchmark(env: LlvmEnv):
     with pytest.raises(ValueError) as ctx:
         env.reset(benchmark=invalid_benchmark)
 
-    assert str(ctx.value) == f'Unknown benchmark "{invalid_benchmark}"'
+    assert str(ctx.value) == f"Invalid benchmark URI: 'benchmark://{invalid_benchmark}'"
 
 
 def test_invalid_benchmark_data(env: LlvmEnv):
-    benchmark = Benchmark(
-        uri="benchmark://new", program=File(contents="Invalid bitcode".encode("utf-8"))
+    benchmark = Benchmark.from_file_contents(
+        "benchmark://new", "Invalid bitcode".encode("utf-8")
     )
 
     with pytest.raises(ValueError) as ctx:
@@ -45,7 +47,9 @@
 
 def test_invalid_benchmark_missing_file(env: LlvmEnv):
     benchmark = Benchmark(
-        uri="benchmark://new",
+        BenchmarkProto(
+            uri="benchmark://new",
+        )
     )
 
     with pytest.raises(ValueError) as ctx:
@@ -57,9 +61,7 @@
 def test_benchmark_path_not_found(env: LlvmEnv):
     with tempfile.TemporaryDirectory() as tmpdir:
         tmpdir = Path(tmpdir)
-        benchmark = Benchmark(
-            uri="benchmark://new", program=File(uri=f"file:///{tmpdir}/not_found")
-        )
+        benchmark = Benchmark.from_file("benchmark://new", f"{tmpdir}/not_found")
 
         with pytest.raises(FileNotFoundError) as ctx:
             env.reset(benchmark=benchmark)
@@ -72,9 +74,7 @@ def test_benchmark_path_empty_file(env: LlvmEnv):
         tmpdir = Path(tmpdir)
         (tmpdir / "test.bc").touch()
 
-        benchmark = Benchmark(
-            uri="benchmark://new", program=File(uri=f"file:///{tmpdir}/test.bc")
-        )
+        benchmark = Benchmark.from_file("benchmark://new", tmpdir / "test.bc")
 
         with pytest.raises(ValueError) as ctx:
             env.reset(benchmark=benchmark)
@@ -88,9 +88,7 @@ def test_invalid_benchmark_path_contents(env: LlvmEnv):
         with open(str(tmpdir / "test.bc"), "w") as f:
             f.write("Invalid bitcode")
 
-        benchmark = Benchmark(
-            uri="benchmark://new", program=File(uri=f"file:///{tmpdir}/test.bc")
-        )
+        benchmark = Benchmark.from_file("benchmark://new", tmpdir / "test.bc")
 
         with pytest.raises(ValueError) as ctx:
             env.reset(benchmark=benchmark)
@@ -100,7 +98,10 @@
 
 def test_benchmark_path_invalid_protocol(env: LlvmEnv):
     benchmark = Benchmark(
-        uri="benchmark://new", program=File(uri="invalid_protocol://test")
+        "benchmark://new",
+        BenchmarkProto(
+            uri="benchmark://new", program=File(uri="invalid_protocol://test")
+        ),
     )
 
     with pytest.raises(ValueError) as ctx:
@@ -108,22 +109,18 @@
     assert (
         str(ctx.value)
-        == 'Unsupported benchmark URI protocol: "invalid_protocol://test"'
+        == 'Invalid benchmark data URI. Only the file:/// protocol is supported: "invalid_protocol://test"'
     )
 
 
 def test_custom_benchmark(env: LlvmEnv):
-    benchmark = Benchmark(
-        uri="benchmark://new", program=File(uri=f"file:///{EXAMPLE_BITCODE_FILE}")
-    )
+    benchmark = Benchmark.from_file("benchmark://new", EXAMPLE_BITCODE_FILE)
     env.reset(benchmark=benchmark)
     assert env.benchmark == "benchmark://new"
 
 
 def test_custom_benchmark_constructor():
-    benchmark = Benchmark(
-        uri="benchmark://new", program=File(uri=f"file:///{EXAMPLE_BITCODE_FILE}")
-    )
+    benchmark = Benchmark.from_file("benchmark://new", EXAMPLE_BITCODE_FILE)
     env = gym.make("llvm-v0", benchmark=benchmark)
     try:
         env.reset()
@@ -136,7 +133,7 @@ def test_make_benchmark_single_bitcode(env: LlvmEnv):
     benchmark = llvm.make_benchmark(EXAMPLE_BITCODE_FILE)
 
     assert benchmark.uri == f"file:///{EXAMPLE_BITCODE_FILE}"
-    assert benchmark.program.uri == f"file:///{EXAMPLE_BITCODE_FILE}"
+    assert benchmark.proto.program.uri == f"file:///{EXAMPLE_BITCODE_FILE}"
 
     env.reset(benchmark=benchmark)
     assert env.benchmark == benchmark.uri
diff --git a/tests/llvm/datasets/BUILD b/tests/llvm/datasets/BUILD
new file mode 100644
index 0000000000..5c7f35d1ac
--- /dev/null
+++ b/tests/llvm/datasets/BUILD
@@ -0,0 +1,27 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+load("@rules_python//python:defs.bzl", "py_test")
+
+py_test(
+    name = "cBench_test",
+    timeout = "short",
+    srcs = ["cBench_test.py"],
+    deps = [
+        "//compiler_gym/envs/llvm/datasets",
+        "//tests:test_main",
+        "//tests/pytest_plugins:common",
+    ],
+)
+
+py_test(
+    name = "github_test",
+    timeout = "long",
+    srcs = ["github_test.py"],
+    deps = [
+        "//compiler_gym/envs/llvm/datasets",
+        "//tests:test_main",
+        "//tests/pytest_plugins:common",
+    ],
+)
diff --git a/tests/llvm/datasets/cBench_test.py b/tests/llvm/datasets/cBench_test.py
new file mode 100644
index 0000000000..8f675b6918
--- /dev/null
+++ b/tests/llvm/datasets/cBench_test.py
@@ -0,0 +1,56 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Tests for //compiler_gym/envs/llvm/datasets."""
+import tempfile
+from pathlib import Path
+
+import pytest
+
+from compiler_gym.envs.llvm.datasets import CBenchDataset
+from tests.test_main import main
+
+pytest_plugins = ["tests.pytest_plugins.common"]
+
+
+@pytest.fixture(scope="module")
+def cbench_dataset() -> CBenchDataset:
+    with tempfile.TemporaryDirectory() as d:
+        yield CBenchDataset(site_data_base=Path(d))
+
+
+def test_cBench_count(cbench_dataset: CBenchDataset):
+    assert cbench_dataset.n == 23
+
+
+def test_cBench_uris(cbench_dataset: CBenchDataset):
+    assert list(cbench_dataset.benchmark_uris()) == [
+        "benchmark://cBench-v1/adpcm",
+        "benchmark://cBench-v1/bitcount",
+        "benchmark://cBench-v1/blowfish",
+        "benchmark://cBench-v1/bzip2",
+        "benchmark://cBench-v1/crc32",
+        "benchmark://cBench-v1/dijkstra",
+        "benchmark://cBench-v1/ghostscript",
+        "benchmark://cBench-v1/gsm",
+        "benchmark://cBench-v1/ispell",
+        "benchmark://cBench-v1/jpeg-c",
+        "benchmark://cBench-v1/jpeg-d",
+        "benchmark://cBench-v1/lame",
+        "benchmark://cBench-v1/patricia",
+        "benchmark://cBench-v1/qsort",
+        "benchmark://cBench-v1/rijndael",
+        "benchmark://cBench-v1/sha",
+        "benchmark://cBench-v1/stringsearch",
+        "benchmark://cBench-v1/stringsearch2",
+        "benchmark://cBench-v1/susan",
+        "benchmark://cBench-v1/tiff2bw",
+        "benchmark://cBench-v1/tiff2rgba",
+        "benchmark://cBench-v1/tiffdither",
+        "benchmark://cBench-v1/tiffmedian",
+    ]
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/llvm/datasets/github_test.py b/tests/llvm/datasets/github_test.py
new file mode 100644
index 0000000000..bfa1c6951b
--- /dev/null
+++ b/tests/llvm/datasets/github_test.py
@@ -0,0 +1,30 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Tests for //compiler_gym/envs/llvm/datasets."""
+import tempfile
+from pathlib import Path
+
+import pytest
+
+from compiler_gym.envs.llvm.datasets import GitHubDataset
+from tests.pytest_plugins.common import skip_on_ci
+from tests.test_main import main
+
+pytest_plugins = ["tests.pytest_plugins.common"]
+
+
+@pytest.fixture(scope="module")
+def github_dataset() -> GitHubDataset:
+    with tempfile.TemporaryDirectory() as d:
+        yield GitHubDataset(site_data_base=Path(d))
+
+
+@skip_on_ci
+def test_github_count(github_dataset: GitHubDataset):
+    assert github_dataset.n == 50708
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/llvm/datasets_test.py b/tests/llvm/datasets_test.py
index adbb7f9f5b..aeacdc5f35 100644
--- a/tests/llvm/datasets_test.py
+++ b/tests/llvm/datasets_test.py
@@ -27,41 +27,6 @@ def test_validate_sha_output_invalid():
     assert legacy_datasets.validate_sha_output(output)
 
 
-def test_default_cBench_dataset_require(tmpwd, temporary_environ):
-    """Test that cBench is downloaded."""
-    del temporary_environ
-
-    os.environ["COMPILER_GYM_SITE_DATA"] = str(tmpwd / "site_data")
-    env = gym.make("llvm-v0")
-    try:
-        assert not env.benchmarks, "Sanity check"
-
-        # Datasaet is downloaded.
-        assert env.require_dataset("cBench-v1")
-        assert env.benchmarks
-
-        # Dataset is already downloaded.
-        assert not env.require_dataset("cBench-v1")
-    finally:
-        env.close()
-
-
-def test_default_cBench_on_reset(tmpwd, temporary_environ):
-    """Test that cBench is downloaded by default when no benchmarks are available."""
-    del temporary_environ
-
-    os.environ["COMPILER_GYM_SITE_DATA"] = str(tmpwd / "site_data")
-    env = gym.make("llvm-v0")
-    try:
-        assert not env.benchmarks, "Sanity check"
-
-        env.reset()
-        assert env.benchmarks
-        assert env.benchmark.startswith("benchmark://cBench-v1/")
-    finally:
-        env.close()
-
-
 @pytest.mark.parametrize("benchmark_name", ["benchmark://npb-v0/1", "npb-v0/1"])
 def test_dataset_required(tmpwd, temporary_environ, benchmark_name):
     """Test that the required dataset is downlaoded when a benchmark is specified."""
     os.environ["COMPILER_GYM_SITE_DATA"] = str(tmpwd / "site_data")
@@ -71,13 +36,12 @@
     env = gym.make("llvm-v0")
     try:
         env.reset(benchmark=benchmark_name)
-
-        assert env.benchmarks
         assert env.benchmark.startswith("benchmark://npb-v0/")
     finally:
         env.close()
 
 
+@pytest.mark.xfail(strict=True)
 def test_cBench_v0_deprecation(env: LlvmEnv):
     """Test that cBench-v0 emits a deprecation warning when used."""
     with pytest.deprecated_call(
diff --git a/tests/llvm/llvm_benchmarks_test.py b/tests/llvm/llvm_benchmarks_test.py
index 076385c2ae..d5c43e0346 100644
--- a/tests/llvm/llvm_benchmarks_test.py
+++ b/tests/llvm/llvm_benchmarks_test.py
@@ -8,8 +8,10 @@
 
 import pytest
 
+from compiler_gym.datasets import Benchmark
 from compiler_gym.envs import CompilerEnv
-from compiler_gym.service.proto import Benchmark, File
+from compiler_gym.service.proto import Benchmark as BenchmarkProto
+from compiler_gym.service.proto import File
 from tests.test_main import main
 
 pytest_plugins = ["tests.pytest_plugins.llvm"]
@@ -19,12 +21,14 @@ def test_add_benchmark_invalid_protocol(env: CompilerEnv):
     with pytest.raises(ValueError) as ctx:
         env.reset(
             benchmark=Benchmark(
-                uri="benchmark://foo", program=File(uri="https://invalid/protocol")
+                BenchmarkProto(
+                    uri="benchmark://foo", program=File(uri="https://invalid/protocol")
+                ),
             )
         )
 
-    assert (
-        str(ctx.value)
-        == 'Unsupported benchmark URI protocol: "https://invalid/protocol"'
+    assert str(ctx.value) == (
+        "Invalid benchmark data URI. "
+        'Only the file:/// protocol is supported: "https://invalid/protocol"'
     )
 
@@ -32,11 +36,7 @@ def test_add_benchmark_invalid_path(env: CompilerEnv):
     with tempfile.TemporaryDirectory() as d:
         tmp = Path(d) / "not_a_file"
         with pytest.raises(FileNotFoundError) as ctx:
-            env.reset(
-                benchmark=Benchmark(
-                    uri="benchmark://foo", program=File(uri=f"file:///{tmp}")
-                )
-            )
+            env.reset(benchmark=Benchmark.from_file("benchmark://foo", tmp))
 
         assert str(ctx.value) == f'File not found: "{tmp}"'
 
diff --git a/tests/llvm/llvm_env_test.py b/tests/llvm/llvm_env_test.py
index 7c8a59b493..a0c625e3fe 100644
--- a/tests/llvm/llvm_env_test.py
+++ b/tests/llvm/llvm_env_test.py
@@ -92,26 +92,26 @@ def test_commandline(env: CompilerEnv):
     ]
 
 
-def test_uri_substring_candidate_match(env: CompilerEnv):
+def test_uri_substring_no_match(env: CompilerEnv):
     env.reset(benchmark="benchmark://cBench-v1/crc32")
     assert env.benchmark == "benchmark://cBench-v1/crc32"
 
-    env.reset(benchmark="benchmark://cBench-v1/crc3")
-    assert env.benchmark == "benchmark://cBench-v1/crc32"
+    with pytest.raises(LookupError):
+        env.reset(benchmark="benchmark://cBench-v1/crc3")
 
-    env.reset(benchmark="benchmark://cBench-v1/cr")
-    assert env.benchmark == "benchmark://cBench-v1/crc32"
+    with pytest.raises(LookupError):
+        env.reset(benchmark="benchmark://cBench-v1/cr")
 
 
-def test_uri_substring_candidate_match_infer_protocol(env: CompilerEnv):
+def test_uri_substring_candidate_no_match_infer_protocol(env: CompilerEnv):
     env.reset(benchmark="cBench-v1/crc32")
     assert env.benchmark == "benchmark://cBench-v1/crc32"
 
-    env.reset(benchmark="cBench-v1/crc3")
-    assert env.benchmark == "benchmark://cBench-v1/crc32"
+    with pytest.raises(LookupError):
+        env.reset(benchmark="cBench-v1/crc3")
 
-    env.reset(benchmark="cBench-v1/cr")
-    assert env.benchmark == "benchmark://cBench-v1/crc32"
+    with pytest.raises(LookupError):
+        env.reset(benchmark="cBench-v1/cr")
 
 
 def test_reset_to_force_benchmark(env: CompilerEnv):
@@ -157,7 +157,7 @@ def test_change_benchmark_mid_episode(env: LlvmEnv):
 def test_set_benchmark_invalid_type(env: LlvmEnv):
     with pytest.raises(TypeError) as ctx:
         env.benchmark = 10
-    assert str(ctx.value) == "Unsupported benchmark type: int"
+    assert str(ctx.value) == "Expected a Benchmark instance, received: 'int'"
 
 
 def test_gym_make_kwargs():