Skip to content

Commit

Permalink
WIP: New dataset class.
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisCummins committed Mar 2, 2021
1 parent 804fead commit da7a330
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 1 deletion.
7 changes: 7 additions & 0 deletions compiler_gym/datasets/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,18 @@ py_library(
],
)

# Library target wrapping benchmark.py. Visible to every package underneath
# //compiler_gym; the ":dataset" target in this file depends on it.
py_library(
name = "benchmark",
srcs = ["benchmark.py"],
visibility = ["//compiler_gym:__subpackages__"],
)

# Library target wrapping dataset.py. Visible to every package underneath
# //compiler_gym.
py_library(
name = "dataset",
srcs = ["dataset.py"],
visibility = ["//compiler_gym:__subpackages__"],
deps = [
# dataset.py imports Benchmark from compiler_gym.datasets.benchmark.
":benchmark",
# dataset.py imports download from compiler_gym.util.download.
"//compiler_gym/util",
],
)
90 changes: 89 additions & 1 deletion compiler_gym/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,103 @@
import io
import json
import os
import re
import shutil
import tarfile
from pathlib import Path
from typing import List, NamedTuple, Optional, Union
from typing import Callable, Iterable, List, NamedTuple, Optional, Union

import fasteners

from compiler_gym.datasets.benchmark import Benchmark
from compiler_gym.util.download import download

# Regular expression that matches the full three-part format of a benchmark URI:
#
#     <protocol>://<dataset>/<id>
#
# E.g. "benchmark://foo-v0/" or "benchmark://foo-v0/program".
#
# Named groups:
#   prefix   - "<protocol>://<dataset>", i.e. everything before the final
#              separator slash.
#   protocol - letters, digits, "_" and "-".
#   dataset  - letters, digits, "_" and "-", ending in a "-v<N>" version suffix.
#   id       - the (possibly empty) remainder; must not contain whitespace.
#
# The character classes deliberately spell out "A-Z": the range "A-z" would
# also admit the punctuation characters that sit between "Z" and "a" in ASCII
# ("[", "\", "]", "^", "`"). The literal "-" is placed last in each class so
# that it cannot be read as a range operator.
BENCHMARK_URI_RE = re.compile(
    r"(?P<prefix>(?P<protocol>[a-zA-Z0-9_-]+)://(?P<dataset>[a-zA-Z0-9_-]+-v[0-9]+))/(?P<id>[^\s]*)$"
)


class Dataset(object):
    """Base class for a named collection of benchmarks.

    This class stores the descriptive metadata of a dataset and declares the
    lookup interface. :meth:`benchmark_ids` and :meth:`benchmark` raise
    :code:`NotImplementedError` here and must be provided by subclasses.
    """

    def __init__(
        self, name: str, description: str, license: str, protocol: str = "benchmark://"
    ):
        """Store the dataset metadata.

        :param name: The name of the dataset.
        :param description: A description of the dataset.
        :param license: The license of the dataset.
        :param protocol: The protocol of the dataset's benchmark URIs.
        """
        # TODO: Sanity check dataset name.
        self._name = name
        self._description = description
        self._license = license
        self._protocol = protocol

    @property
    def name(self) -> str:
        """The name passed to the constructor."""
        return self._name

    @property
    def description(self) -> str:
        """The description passed to the constructor."""
        return self._description

    @property
    def license(self) -> str:
        """The license passed to the constructor."""
        return self._license

    @property
    def protocol(self) -> str:
        """The protocol passed to the constructor."""
        return self._protocol

    def install(self, path_to_use: Path) -> bool:
        """Install the dataset.

        The base implementation is a placeholder that does nothing.

        :param path_to_use: Presumably the location to install into; unused by
            this placeholder.
        """
        # NOTE(review): the signature promises a bool but the placeholder
        # returns None. Decide what the return value means before subclasses
        # start relying on it.

    def benchmark_ids(self) -> Iterable[str]:
        """Return an iterator over the IDs of the benchmarks in this dataset.

        The order of the IDs must be consistent across runs.
        """
        raise NotImplementedError("abstract class")

    def benchmarks(self) -> Iterable["Benchmark"]:
        """Possibly lazy list of benchmarks."""
        # Default implementation. Subclasses may wish to provide an optimized
        # version.
        for benchmark_id in self.benchmark_ids():
            yield self.benchmark(benchmark_id)

    def benchmark(self, uri: Optional[str] = None) -> "Benchmark":
        """Return a single benchmark from this dataset.

        :param uri: Selects the benchmark to return. :meth:`benchmarks` passes
            the values produced by :meth:`benchmark_ids` here.
        :raise LookupError: If :code:`uri` is provided but does not exist.
        """
        # NOTE(review): the behavior for uri=None is not defined by this base
        # class; presumably subclasses choose a benchmark themselves.
        raise NotImplementedError("abstract class")


def get_dataset_benchmark_dispatcher(
    datasets: Iterable[Dataset],
) -> Callable[[str], Benchmark]:
    """Return a function that resolves a benchmark URI using the given datasets.

    The datasets are indexed once, by their :code:`<protocol>://<name>`
    prefix. If two datasets share a prefix, the later one wins.

    The returned function takes a URI string, finds the dataset whose prefix
    matches it, and returns the result of calling that dataset's
    :code:`benchmark()` method with the ID part of the URI. A URI without a
    :code:`://` separator is treated as if it began with
    :code:`benchmark://`. The returned function raises:

    * :code:`ValueError` if the benchmark format was invalid.
    * :code:`LookupError` if no dataset is registered for the URI's prefix.

    :param datasets: The datasets to dispatch to.
    :return: A callable mapping a benchmark URI to a benchmark.
    """

    def dataset_prefix(dataset: Dataset) -> str:
        # Dataset's default protocol already ends in "://". Strip it so that
        # the key has the same "<protocol>://<name>" shape as the "prefix"
        # group of BENCHMARK_URI_RE, whichever way the protocol was spelled.
        protocol = dataset.protocol
        if protocol.endswith("://"):
            protocol = protocol[: -len("://")]
        return f"{protocol}://{dataset.name}"

    dataset_lookup_table = {dataset_prefix(dataset): dataset for dataset in datasets}

    def dispatcher(uri: str) -> Benchmark:
        if "://" not in uri:
            uri = f"benchmark://{uri}"
        match = BENCHMARK_URI_RE.match(uri)
        if not match:
            raise ValueError(f"Invalid URI format for benchmark: {uri}")

        prefix = match.group("prefix")
        if prefix not in dataset_lookup_table:
            raise LookupError(f"No dataset found for URI: {prefix}")

        # NOTE(review): for a URI such as "benchmark://foo-v0/" the ID is the
        # empty string, which is passed through unchanged rather than as None.
        return dataset_lookup_table[prefix].benchmark(match.group("id"))

    return dispatcher


class LegacyDataset(NamedTuple):
"""A collection of benchmarks for use by an environment."""
Expand Down

0 comments on commit da7a330

Please sign in to comment.