Introduce TrialRunner Abstraction (#720)
This is another step in adding support for parallel trial execution
(#380).

Here we separate out the running of an individual trial into a single
class: TrialRunner.

Multiple TrialRunners are instantiated at CLI invocation with the
`--num-trial-runners` argument.
Each TrialRunner is associated with its own copy of the root Environment
and its Services (important for eventual parallelization), and is made
unique by a distinct `trial_runner_id` value that is also included in that
Environment's global_config.
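
Below is a rough sketch of that fan-out (illustrative names only; not the actual mlos_bench API):

```python
# Illustrative sketch only -- not the actual mlos_bench implementation.
from dataclasses import dataclass


@dataclass
class RunnerEnv:
    """Stand-in for one TrialRunner's copy of the root Environment."""

    global_config: dict


def fan_out(global_config: dict, num_trial_runners: int) -> list[RunnerEnv]:
    """Create one Environment copy per TrialRunner, each with a unique id."""
    return [
        RunnerEnv(global_config=dict(global_config, trial_runner_id=i))
        for i in range(num_trial_runners)
    ]


envs = fan_out({"experiment_id": "exp1"}, num_trial_runners=3)
assert [e.global_config["trial_runner_id"] for e in envs] == [0, 1, 2]
```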

In future PRs we will add:
- New Scheduler implementations to run TrialRunners in parallel.
- Async polling of status results in each TrialRunner independently.

---------

Co-authored-by: Sergiy Matusevych <sergiym@microsoft.com>
Co-authored-by: Sergiy Matusevych <sergiy.matusevych@gmail.com>
3 people authored Jan 10, 2025
1 parent 7cc74fd commit 6ffe546
Showing 33 changed files with 911 additions and 155 deletions.
conftest.py (2 changes: 1 addition & 1 deletion)
@@ -48,7 +48,7 @@ def pytest_configure(config: pytest.Config) -> None:
     # Set pandas display options to make inline tests stable.
     import pandas  # pylint: disable=import-outside-toplevel

-    pandas.options.display.width = 120
+    pandas.options.display.width = 150
     pandas.options.display.max_columns = 10

     # Create a temporary directory for sharing files between master and worker nodes.
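For context (general pandas behavior, not part of this diff): DataFrame reprs wrap at `display.width` characters and truncate to `display.max_columns` columns, so pinning both keeps inline doctest output stable across terminals. A quick demonstration:

```python
import pandas as pd

# Mirror the conftest.py settings above.
pd.options.display.width = 150
pd.options.display.max_columns = 10

df = pd.DataFrame({f"col{i}": range(3) for i in range(12)})
print(df)  # shows at most 10 columns and wraps lines at 150 characters
```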
doc/source/conf.py (2 changes: 2 additions & 0 deletions)
@@ -216,7 +216,9 @@ def setup(app: SphinxApp) -> None:
("py:class", "numpy.typing.NDArray"),
# External classes that refuse to resolve:
("py:class", "contextlib.nullcontext"),
("py:class", "sqlalchemy.engine.Connection"),
("py:class", "sqlalchemy.engine.Engine"),
("py:class", "sqlalchemy.schema.Table"),
("py:class", "sqlalchemy.MetaData"),
("py:exc", "jsonschema.exceptions.SchemaError"),
("py:exc", "jsonschema.exceptions.ValidationError"),
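Background on these entries (standard Sphinx behavior): in nitpicky mode, Sphinx warns about every cross-reference it cannot resolve, and `nitpick_ignore` lists `(domain:role, target)` pairs to suppress, which is why externally defined sqlalchemy classes are enumerated here. A minimal `conf.py` fragment showing the mechanism (illustrative subset):

```python
# Fragment of a Sphinx conf.py (illustrative; see doc/source/conf.py for the real list).
nitpicky = True  # warn on unresolvable cross-references
nitpick_ignore = [
    # External classes that refuse to resolve:
    ("py:class", "sqlalchemy.engine.Engine"),
    ("py:class", "sqlalchemy.MetaData"),
]
```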
mlos_bench/mlos_bench/config/schemas/cli/cli-schema.json (7 changes: 7 additions & 0 deletions)
@@ -79,6 +79,13 @@
       "examples": [3, 5]
     },

+    "num_trial_runners": {
+      "description": "Number of trial runner instances to use to execute benchmark environments. Individual TrialRunners can be identified in configs with $trial_runner_id and optionally run in parallel.",
+      "type": "integer",
+      "minimum": 1,
+      "examples": [1, 3, 5, 10]
+    },
+
     "storage": {
       "description": "Path to the json config describing the storage backend to use.",
       "$ref": "#/$defs/json_config_path"
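To see what the new field accepts and rejects, here is a small check with the `jsonschema` package (the `num_trial_runners` fragment is copied from the diff; the object wrapper around it is assumed):

```python
import jsonschema

schema = {
    "type": "object",
    "properties": {
        "num_trial_runners": {"type": "integer", "minimum": 1},
    },
}

jsonschema.validate(instance={"num_trial_runners": 3}, schema=schema)  # passes
try:
    jsonschema.validate(instance={"num_trial_runners": 0}, schema=schema)
except jsonschema.exceptions.ValidationError as ex:
    print(ex.message)  # "0 is less than the minimum of 1"
```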
mlos_bench/mlos_bench/environments/base_environment.py (35 changes: 31 additions & 4 deletions)
@@ -33,6 +33,15 @@ class Environment(ContextManager, metaclass=abc.ABCMeta):
     # pylint: disable=too-many-instance-attributes
     """An abstract base of all benchmark environments."""

+    # Should be provided by the runtime.
+    _COMMON_CONST_ARGS = {
+        "trial_runner_id",
+    }
+    _COMMON_REQ_ARGS = {
+        "experiment_id",
+        "trial_id",
+    }
+
     @classmethod
     def new(  # pylint: disable=too-many-arguments
         cls,
@@ -113,6 +122,7 @@ def __init__(  # pylint: disable=too-many-arguments
             An optional service object (e.g., providing methods to
             deploy or reboot a VM/Host, etc.).
         """
+        global_config = global_config or {}
         self._validate_json_config(config, name)
         self.name = name
         self.config = config
@@ -122,6 +132,10 @@ def __init__(  # pylint: disable=too-many-arguments
         self._in_context = False
         self._const_args: dict[str, TunableValue] = config.get("const_args", {})

+        # Make some usual runtime arguments available for tests.
+        for arg in self._COMMON_CONST_ARGS | self._COMMON_REQ_ARGS:
+            global_config.setdefault(arg, self._const_args.get(arg, None))
+
         if _LOG.isEnabledFor(logging.DEBUG):
             _LOG.debug(
                 "Environment: '%s' Service: %s",
@@ -149,11 +163,12 @@ def __init__(  # pylint: disable=too-many-arguments
         self._tunable_params = tunables.subgroup(groups)

         # If a parameter comes from the tunables, do not require it in the const_args or globals
-        req_args = set(config.get("required_args", [])) - set(
-            self._tunable_params.get_param_values().keys()
+        req_args = (
+            set(config.get("required_args", [])) - self._tunable_params.get_param_values().keys()
         )
+        req_args.update(self._COMMON_REQ_ARGS | self._COMMON_CONST_ARGS)
         merge_parameters(dest=self._const_args, source=global_config, required_keys=req_args)
-        self._const_args = self._expand_vars(self._const_args, global_config or {})
+        self._const_args = self._expand_vars(self._const_args, global_config)

         self._params = self._combine_tunables(self._tunable_params)
         _LOG.debug("Parameters for '%s' :: %s", name, self._params)
@@ -321,6 +336,18 @@ def tunable_params(self) -> TunableGroups:
         """
         return self._tunable_params

+    @property
+    def const_args(self) -> dict[str, TunableValue]:
+        """
+        Get the constant arguments for this Environment.
+
+        Returns
+        -------
+        parameters : Dict[str, TunableValue]
+            Key/value pairs of all environment const_args parameters.
+        """
+        return self._const_args.copy()
+
     @property
     def parameters(self) -> dict[str, TunableValue]:
         """
@@ -334,7 +361,7 @@ def parameters(self) -> dict[str, TunableValue]:
         Key/value pairs of all environment parameters
         (i.e., `const_args` and `tunable_params`).
         """
-        return self._params
+        return self._params.copy()

     def setup(self, tunables: TunableGroups, global_config: dict | None = None) -> bool:
         """
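The switch to returning `.copy()` in both properties is a defensive-copy pattern; a small self-contained illustration (toy class, not the actual Environment):

```python
class Env:
    """Toy stand-in showing why the properties return copies."""

    def __init__(self) -> None:
        self._params = {"vm_size": "Standard_B2s"}

    @property
    def parameters(self) -> dict:
        return self._params.copy()  # callers get a snapshot, not the internals


env = Env()
snapshot = env.parameters
snapshot["vm_size"] = "mutated"
assert env.parameters["vm_size"] == "Standard_B2s"  # internal state unchanged
```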
mlos_bench/mlos_bench/launcher.py (88 changes: 69 additions & 19 deletions)
@@ -23,6 +23,7 @@
 from mlos_bench.optimizers.mock_optimizer import MockOptimizer
 from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer
 from mlos_bench.schedulers.base_scheduler import Scheduler
+from mlos_bench.schedulers.trial_runner import TrialRunner
 from mlos_bench.services.base_service import Service
 from mlos_bench.services.config_persistence import ConfigPersistenceService
 from mlos_bench.services.local.local_exec import LocalExecService
@@ -45,6 +46,7 @@ class Launcher:

     def __init__(self, description: str, long_text: str = "", argv: list[str] | None = None):
         # pylint: disable=too-many-statements
+        # pylint: disable=too-complex
         # pylint: disable=too-many-locals
         _LOG.info("Launch: %s", description)
         epilog = """
@@ -87,8 +89,6 @@ def __init__(self, description: str, long_text: str = "", argv: list[str] | None = None):
         log_handler.setFormatter(logging.Formatter(_LOG_FORMAT))
         logging.root.addHandler(log_handler)

-        self._parent_service: Service = LocalExecService(parent=self._config_loader)
-
         # Prepare global_config from a combination of global config files, cli
         # configs, and cli args.
         args_dict = vars(args)
@@ -109,6 +109,7 @@ def __init__(self, description: str, long_text: str = "", argv: list[str] | None = None):
             args_rest=args_rest,
             global_config=cli_config_args,
         )
+        # TODO: Can we generalize these two rules using excluded_cli_args?
         # experiment_id is generally taken from --globals files, but we also allow
         # overriding it on the CLI.
         # It's useful to keep it there explicitly mostly for the --help output.
@@ -118,13 +119,31 @@ def __init__(self, description: str, long_text: str = "", argv: list[str] | None = None):
         # set it via command line
         if args.trial_config_repeat_count:
             self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count
+        self.global_config.setdefault("num_trial_runners", 1)
+        if args.num_trial_runners:
+            self.global_config["num_trial_runners"] = args.num_trial_runners
+        if self.global_config["num_trial_runners"] <= 0:
+            raise ValueError(
+                f"""Invalid num_trial_runners: {self.global_config["num_trial_runners"]}"""
+            )
         # Ensure that the trial_id is present since it gets used by some other
         # configs but is typically controlled by the run optimize loop.
         self.global_config.setdefault("trial_id", 1)

         self.global_config = DictTemplater(self.global_config).expand_vars(use_os_env=True)
         assert isinstance(self.global_config, dict)

+        # --service cli args should override the config file values.
+        service_files: list[str] = config.get("services", []) + (args.service or [])
+        # Add a LocalExecService as the parent service for all other services.
+        self._parent_service: Service = LocalExecService(parent=self._config_loader)
+        assert isinstance(self._parent_service, SupportsConfigLoading)
+        self._parent_service = self._parent_service.load_services(
+            service_files,
+            self.global_config,
+            self._parent_service,
+        )
+
         self.storage = self._load_storage(
             args.storage or config.get("storage"),
             lazy_schema_create=False if args.create_update_storage_schema_only else None,
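The precedence implemented above, distilled into a standalone sketch (hypothetical helper, not the Launcher itself): default to 1, let an explicit CLI value win, and reject non-positive values (e.g., a bad setting from `--globals`):

```python
def resolve_num_trial_runners(global_config: dict, cli_value: int | None) -> int:
    global_config.setdefault("num_trial_runners", 1)  # default
    if cli_value:  # an explicit CLI value overrides the globals
        global_config["num_trial_runners"] = cli_value
    if global_config["num_trial_runners"] <= 0:
        raise ValueError(f"Invalid num_trial_runners: {global_config['num_trial_runners']}")
    return global_config["num_trial_runners"]


assert resolve_num_trial_runners({}, None) == 1
assert resolve_num_trial_runners({"num_trial_runners": 4}, cli_value=2) == 2
```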
@@ -136,31 +155,34 @@ def __init__(self, description: str, long_text: str = "", argv: list[str] | None = None):
             self.storage.update_schema()
             sys.exit(0)

-        # --service cli args should override the config file values.
-        service_files: list[str] = config.get("services", []) + (args.service or [])
-        assert isinstance(self._parent_service, SupportsConfigLoading)
-        self._parent_service = self._parent_service.load_services(
-            service_files,
-            self.global_config,
-            self._parent_service,
-        )
-
         env_path = args.environment or config.get("environment")
         if not env_path:
             _LOG.error("No environment config specified.")
             parser.error(
                 "At least the Environment config must be specified."
-                + " Run `mlos_bench --help` and consult `README.md` for more info."
+                " Run `mlos_bench --help` and consult `README.md` for more info."
             )
         self.root_env_config = self._config_loader.resolve_path(env_path)

-        self.environment: Environment = self._config_loader.load_environment(
-            self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service
+        # Create the TrialRunners and their Environments and Services from the JSON files.
+        self.trial_runners = TrialRunner.create_from_json(
+            config_loader=self._config_loader,
+            global_config=self.global_config,
+            svcs_json=service_files,
+            env_json=self.root_env_config,
+            num_trial_runners=self.global_config["num_trial_runners"],
         )

-        _LOG.info("Init environment: %s", self.environment)
+        _LOG.info(
+            "Init %d trial runners for environments: %s",
+            len(self.trial_runners),
+            [trial_runner.environment for trial_runner in self.trial_runners],
+        )

-        # NOTE: Init tunable values *after* the Environment, but *before* the Optimizer
+        # NOTE: Init tunable values *after* the Environment(s), but *before* the Optimizer
+        # TODO: should we assign the same or different tunables for all TrialRunner Environments?
         self.tunables = self._init_tunable_values(
+            self.trial_runners[0].environment,
             args.random_init or config.get("random_init", False),
             config.get("random_seed") if args.random_seed is None else args.random_seed,
             config.get("tunable_values", []) + (args.tunable_values or []),
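One small cleanup in the hunk above: `parser.error()` now relies on Python's implicit concatenation of adjacent string literals rather than a runtime `+`:

```python
msg = (
    "At least the Environment config must be specified."
    " Run `mlos_bench --help` and consult `README.md` for more info."
)
assert "specified. Run `mlos_bench --help`" in msg  # one string at compile time
```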
@@ -183,6 +205,21 @@ def config_loader(self) -> ConfigPersistenceService:
         """Get the config loader service."""
         return self._config_loader

+    @property
+    def root_environment(self) -> Environment:
+        """
+        Gets the root (prototypical) Environment from the first TrialRunner.
+
+        Note: All TrialRunners have the same Environment config and are made
+        unique by their use of the unique trial_runner_id assigned to each
+        TrialRunner's Environment's global_config.
+
+        Notes
+        -----
+        This is mostly for convenience and backwards compatibility.
+        """
+        return self.trial_runners[0].environment
+
     @property
     def service(self) -> Service:
         """Get the parent service."""
@@ -287,6 +324,18 @@ def add_argument(self, *args: Any, **kwargs: Any) -> None:
             ),
         )

+        parser.add_argument(
+            "--num_trial_runners",
+            "--num-trial-runners",
+            required=False,
+            type=int,
+            help=(
+                "Number of TrialRunners to use for executing benchmark Environments. "
+                "Individual TrialRunners can be identified in configs with $trial_runner_id "
+                "and optionally run in parallel."
+            ),
+        )
+
         path_args_tracker.add_argument(
             "--scheduler",
             required=False,
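Both spellings registered above resolve to the same destination, since argparse derives `dest` from the first long option string; a quick standalone check (not the Launcher's actual parser):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num_trial_runners", "--num-trial-runners", required=False, type=int)

assert parser.parse_args(["--num-trial-runners", "3"]).num_trial_runners == 3
assert parser.parse_args(["--num_trial_runners", "5"]).num_trial_runners == 5
```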
@@ -449,14 +498,15 @@ def _load_config(

     def _init_tunable_values(
         self,
+        env: Environment,
         random_init: bool,
         seed: int | None,
         args_tunables: str | None,
     ) -> TunableGroups:
         """Initialize the tunables and load key/value pairs of the tunable values from
         given JSON files, if specified.
         """
-        tunables = self.environment.tunable_params
+        tunables = env.tunable_params
         _LOG.debug("Init tunables: default = %s", tunables)

         if random_init:
@@ -561,7 +611,7 @@ def _load_scheduler(self, args_scheduler: str | None) -> Scheduler:
                 "teardown": self.teardown,
             },
             global_config=self.global_config,
-            environment=self.environment,
+            trial_runners=self.trial_runners,
             optimizer=self.optimizer,
             storage=self.storage,
             root_env_config=self.root_env_config,
@@ -571,7 +621,7 @@ def _load_scheduler(self, args_scheduler: str | None) -> Scheduler:
         return self._config_loader.build_scheduler(
             config=class_config,
             global_config=self.global_config,
-            environment=self.environment,
+            trial_runners=self.trial_runners,
             optimizer=self.optimizer,
             storage=self.storage,
             root_env_config=self.root_env_config,
(Diffs for the remaining 28 changed files are not shown.)