Skip to content

Commit

Permalink
Add support for prefix context manager in logger (from #529)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocamonde committed Sep 26, 2022
1 parent 32467bf commit e56c0e3
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 83 deletions.
148 changes: 100 additions & 48 deletions src/imitation/util/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,14 @@
import contextlib
import datetime
import os
import sys
import tempfile
from typing import Any, Dict, Generator, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, Union

import stable_baselines3.common.logger as sb_logger

from imitation.data import types


def make_output_format(
    _format: str,
    log_dir: str,
    log_suffix: str = "",
    max_length: int = 40,
) -> sb_logger.KVWriter:
    """Returns a logger for the requested format.

    Args:
        _format: the requested format to log to
            ('stdout', 'log', 'json' or 'csv' or 'tensorboard').
        log_dir: the logging directory.
        log_suffix: the suffix for the log file.
        max_length: the maximum length beyond which the keys get truncated.

    Returns:
        the logger.
    """
    os.makedirs(log_dir, exist_ok=True)
    # 'stdout' and 'log' get the human-readable writer; everything else is
    # delegated to the stable-baselines3 factory.
    if _format == "stdout":
        writer = sb_logger.HumanOutputFormat(sys.stdout, max_length=max_length)
    elif _format == "log":
        log_path = os.path.join(log_dir, f"log{log_suffix}.txt")
        writer = sb_logger.HumanOutputFormat(log_path, max_length=max_length)
    else:
        writer = sb_logger.make_output_format(_format, log_dir, log_suffix)
    return writer


def _build_output_formats(
folder: str,
format_strs: Sequence[str],
Expand All @@ -50,7 +19,7 @@ def _build_output_formats(
Args:
folder: Path to directory that logs are written to.
format_strs: A list of output format strings. For details on available
format_strs: A list of output format strings. For details on available
output formats see `stable_baselines3.logger.make_output_format`.
Returns:
Expand All @@ -62,7 +31,7 @@ def _build_output_formats(
if f == "wandb":
output_formats.append(WandbOutputFormat())
else:
output_formats.append(make_output_format(f, folder))
output_formats.append(sb_logger.make_output_format(f, folder))
return output_formats


Expand All @@ -72,8 +41,57 @@ class HierarchicalLogger(sb_logger.Logger):
`self.accumulate_means` creates a context manager. While in this context,
values are logged to a sub-logger, with only mean values recorded in the
top-level (root) logger.
>>> import tempfile
>>> with tempfile.TemporaryDirectory() as dir:
... logger: HierarchicalLogger = configure(dir, ('log',))
... # record the key value pair (loss, 1.0) to path `dir`
... # at step 1.
... logger.record("loss", 1.0)
... logger.dump(step=1)
... with logger.accumulate_means("dataset"):
... # record the key value pair `("raw/dataset/entropy", 5.0)` to path
... # `dir/raw/dataset` at step 100
... logger.record("entropy", 5.0)
... logger.dump(step=100)
... # record the key value pair `("raw/dataset/entropy", 6.0)` to path
... # `dir/raw/dataset` at step 200
... logger.record("entropy", 6.0)
... logger.dump(step=200)
... # record the key value pair `("mean/dataset/entropy", 5.5)` to path
... # `dir` at step 1.
... logger.dump(step=1)
... with logger.add_prefix("foo"), logger.accumulate_means("bar"):
... # record the key value pair ("raw/foo/bar/biz", 42.0) to path
... # `dir/raw/foo/bar` at step 2000
... logger.record("biz", 42.0)
... logger.dump(step=2000)
... # record the key value pair `("mean/foo/bar/biz", 42.0)` to path
... # `dir` at step 1.
... logger.dump(step=1)
... with open(os.path.join(dir, 'log.txt')) as f:
... print(f.read())
-------------------
| loss | 1 |
-------------------
---------------------------------
| mean/ | |
| dataset/entropy | 5.5 |
---------------------------------
-----------------------------
| mean/ | |
| foo/bar/biz | 42 |
-----------------------------
<BLANKLINE>
"""

default_logger: sb_logger.Logger
current_logger: Optional[sb_logger.Logger]
_cached_loggers: Dict[str, sb_logger.Logger]
_prefixes: List[str]
_subdir: Optional[str]
_name: Optional[str]

def __init__(
self,
default_logger: sb_logger.Logger,
Expand All @@ -93,7 +111,9 @@ def __init__(
self.default_logger = default_logger
self.current_logger = None
self._cached_loggers = {}
self._prefixes = []
self._subdir = None
self._name = None
self.format_strs = format_strs
super().__init__(folder=self.default_logger.dir, output_formats=[])

Expand All @@ -103,27 +123,56 @@ def _update_name_to_maps(self) -> None:
self.name_to_excluded = self._logger.name_to_excluded

@contextlib.contextmanager
def add_prefix(self, prefix: str) -> Generator[None, None, None]:
    """Add a prefix to the subdirectory used to accumulate means.

    This prefix only applies when an `accumulate_means` context is active.
    If there are multiple active prefixes, then they are concatenated.

    Args:
        prefix: The prefix to add to the named subdirectory.

    Yields:
        None when the context manager is entered.

    Raises:
        RuntimeError: if an accumulate_means context is already active.
    """
    # Prefixes cannot change mid-accumulation: the sub-logger for the
    # current subdir has already been selected.
    if self.current_logger is not None:
        raise RuntimeError(
            "Cannot add prefix when accumulate_means context is already active.",
        )

    self._prefixes.append(prefix)
    try:
        yield
    finally:
        # Always remove the prefix on exit, even if the body raised.
        self._prefixes.pop()

@contextlib.contextmanager
def accumulate_means(self, name: str) -> Generator[None, None, None]:
"""Temporarily modifies this HierarchicalLogger to accumulate means values.
During this context, `self.record(key, value)` writes the "raw" values in
"{self.default_logger.log_dir}/{subdir}" under the key "raw/{subdir}/{key}".
At the same time, any call to `self.record` will also accumulate mean values
on the default logger by calling
`self.default_logger.record_mean(f"mean/{subdir}/{key}", value)`.
Within this context manager, `self.record(key, value)` writes the "raw" values
in `f"{self.default_logger.log_dir}/{prefix}/{name}"` under the key
`"raw/{prefix}/{name}/{key}"`. At the same time, any call to `self.record` will
also accumulate mean values on the default logger by calling
`self.default_logger.record_mean(f"mean/{prefix}/{name}/{key}", value)`.
During the context, `self.record(key, value)` will write the "raw" values in
`"{self.default_logger.log_dir}/subdir"` under the key "raw/{subdir}/key".
Multiple prefixes may be active at once. In this case the `prefix` is simply the
concatenation of each of the active prefixes in the order they
were created, e.g. if the active `prefixes` are ['foo', 'bar'] then
the `prefix` is 'foo/bar'.
After the context exits, calling `self.dump()` will write the means
of all the "raw" values accumulated during this context to
`self.default_logger` under keys with the prefix `mean/{subdir}/`
`self.default_logger` under keys of the form `mean/{prefix}/{name}/{key}`
Note that the behavior of other logging methods, `log` and `record_mean`
are unmodified and will go straight to the default logger.
Args:
subdir: A string key which determines the `folder` where raw data is
name: A string key which determines the `folder` where raw data is
written and temporary logging prefixes for raw and mean data. Entering
an `accumulate_means` context in the future with the same `subdir`
will safely append to logs written in this folder rather than
Expand All @@ -139,10 +188,11 @@ def accumulate_means(self, subdir: types.AnyPath) -> Generator[None, None, None]
if self.current_logger is not None:
raise RuntimeError("Nested `accumulate_means` context")

subdir = os.path.join(*self._prefixes, name)

if subdir in self._cached_loggers:
logger = self._cached_loggers[subdir]
else:
subdir = types.path_to_str(subdir)
folder = os.path.join(self.default_logger.dir, "raw", subdir)
os.makedirs(folder, exist_ok=True)
output_formats = _build_output_formats(folder, self.format_strs)
Expand All @@ -152,20 +202,22 @@ def accumulate_means(self, subdir: types.AnyPath) -> Generator[None, None, None]
try:
self.current_logger = logger
self._subdir = subdir
self._name = name
self._update_name_to_maps()
yield
finally:
self.current_logger = None
self._subdir = None
self._name = None
self._update_name_to_maps()

def record(self, key, val, exclude=None):
if self.current_logger is not None: # In accumulate_means context.
assert self._subdir is not None
raw_key = "/".join(["raw", self._subdir, key])
raw_key = "/".join(["raw", *self._prefixes, self._name, key])
self.current_logger.record(raw_key, val, exclude)

mean_key = "/".join(["mean", self._subdir, key])
mean_key = "/".join(["mean", *self._prefixes, self._name, key])
self.default_logger.record_mean(mean_key, val, exclude)
else: # Not in accumulate_means context.
self.default_logger.record(key, val, exclude)
Expand Down Expand Up @@ -269,4 +321,4 @@ def configure(
default_logger = sb_logger.Logger(folder, list(output_formats))
hier_format_strs = [f for f in format_strs if f != "wandb"]
hier_logger = HierarchicalLogger(default_logger, hier_format_strs)
return hier_logger
return hier_logger
80 changes: 45 additions & 35 deletions tests/util/test_logger.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Tests `imitation.util.logger`."""

import csv
import json
import os.path as osp
from collections import defaultdict

Expand All @@ -22,40 +21,13 @@ def _csv_to_dict(csv_path: str) -> dict:
return result


def _json_to_dict(json_path: str) -> dict:
r"""Loads the saved json logging file and convert it to expected dict format.
Args:
json_path: Path of the json log file.
Stored in the format - '{"A": 1, "B": 1}\n{"A": 2}\n{"B": 3}\n'
Returns:
dictionary in the format - `{"A": [1, 2, ""], "B": [1, "", 3]}`
"""
result = defaultdict(list)
with open(json_path, "r") as f:
all_line_dicts = [json.loads(line) for line in f.readlines()]
# get all the keys in the dict so as to add "" if the key is not present in a line
all_keys = set().union(*[list(line_dict.keys()) for line_dict in all_line_dicts])

for line_dict in all_line_dicts:
for key in all_keys:
result[key].append(line_dict.get(key, ""))
return result


def _compare_csv_lines(csv_path: str, expect: dict):
    """Assert that the csv log at `csv_path` parses to exactly `expect`."""
    assert _csv_to_dict(csv_path) == expect


def _compare_json_lines(json_path: str, expect: dict):
    """Assert that the json log at `json_path` parses to exactly `expect`."""
    assert _json_to_dict(json_path) == expect


def test_no_accum(tmpdir):
hier_logger = logger.configure(tmpdir, ["csv", "json"])
hier_logger = logger.configure(tmpdir, ["csv"])
assert hier_logger.get_dir() == tmpdir

# Check that the recorded "A": -1 is overwritten by "A": 1 in the next line.
Expand All @@ -71,12 +43,6 @@ def test_no_accum(tmpdir):
hier_logger.dump()
expect = {"A": [1, 2, ""], "B": [1, "", 3]}
_compare_csv_lines(osp.join(tmpdir, "progress.csv"), expect)
_compare_json_lines(osp.join(tmpdir, "progress.json"), expect)


def test_raise_unknown_format():
with pytest.raises(ValueError, match=r"Unknown format specified:.*"):
logger.make_output_format("txt", "log_dir")


def test_free_form(tmpdir):
Expand Down Expand Up @@ -229,3 +195,47 @@ def test_hard(tmpdir):
_compare_csv_lines(osp.join(tmpdir, "progress.csv"), expect_default)
_compare_csv_lines(osp.join(tmpdir, "raw", "gen", "progress.csv"), expect_raw_gen)
_compare_csv_lines(osp.join(tmpdir, "raw", "disc", "progress.csv"), expect_raw_disc)


def test_prefix(tmpdir):
    """Check that `add_prefix` routes raw and mean logs to the prefixed paths."""
    log = logger.configure(tmpdir)

    # Prefixed accumulation: keys should land under foo/bar.
    with log.add_prefix("foo"), log.accumulate_means("bar"):
        log.record("A", 1)
        log.record("B", 2)
        log.dump()

    # Plain record outside any context goes straight to the default logger.
    log.record("no_context", 1)

    # Unprefixed accumulation still works alongside a prior prefixed one.
    with log.accumulate_means("blat"):
        log.record("C", 3)
        log.dump()

    log.dump()

    _compare_csv_lines(
        osp.join(tmpdir, "progress.csv"),
        {
            "mean/foo/bar/A": [1],
            "mean/foo/bar/B": [2],
            "mean/blat/C": [3],
            "no_context": [1],
        },
    )
    _compare_csv_lines(
        osp.join(tmpdir, "raw", "foo", "bar", "progress.csv"),
        {"raw/foo/bar/A": [1], "raw/foo/bar/B": [2]},
    )
    _compare_csv_lines(
        osp.join(tmpdir, "raw", "blat", "progress.csv"),
        {"raw/blat/C": [3]},
    )


def test_cant_add_prefix_within_accumulate_means(tmpdir):
    """`add_prefix` must raise if entered inside an `accumulate_means` context."""
    hier_logger = logger.configure(tmpdir)
    with pytest.raises(RuntimeError):
        with hier_logger.accumulate_means("foo"):
            with hier_logger.add_prefix("bar"):
                pass  # pragma: no cover

0 comments on commit e56c0e3

Please sign in to comment.