Skip to content

Commit

Permalink
fix: logger accepts arrays as data and use tensorboard_logger Logger (#…
Browse files Browse the repository at this point in the history
…131)

* fix: jax and numpy arrays logging error and Tensorboard Logger issue.
  • Loading branch information
Your-Cheese authored Dec 4, 2024
1 parent 5c48fea commit 0dcf410
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
default_stages: [ "commit", "commit-msg", "push" ]
default_stages: [ "pre-commit", "commit-msg", "pre-push" ]
default_language_version:
python: python3

Expand Down Expand Up @@ -62,7 +62,7 @@ repos:
- id: commitlint
name: "Commit linter"
stages: [ commit-msg ]
additional_dependencies: [ '@commitlint/config-conventional' ]
additional_dependencies: [ "@commitlint/cli",'@commitlint/config-conventional' ]

- repo: https://github.com/codespell-project/codespell
rev: v2.2.2
Expand Down
18 changes: 14 additions & 4 deletions stoix/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
import jax
import neptune
import numpy as np
import tensorboard_logger
import wandb
from colorama import Fore, Style
from jax.typing import ArrayLike
from marl_eval.json_tools import JsonLogger as MarlEvalJsonLogger
from neptune.utils import stringify_unsupported
from omegaconf import DictConfig
from pandas.io.json._normalize import _simple_json_normalize as flatten_dict
from tensorboard_logger import configure, log_value


class LogEvent(Enum):
Expand Down Expand Up @@ -61,6 +61,10 @@ def log(self, metrics: Dict, t: int, t_eval: int, event: LogEvent) -> None:
# {metric1_name: {mean: metric, max: metric, ...}, metric2_name: ...}
metrics = jax.tree_util.tree_map(describe, metrics)

metrics = jax.tree.map(
lambda x: x.item() if isinstance(x, (jax.Array, np.ndarray)) else x, metrics
)

self.logger.log_dict(metrics, t, t_eval, event)

def calc_solve_rate(self, episode_metrics: Dict, event: LogEvent) -> Dict:
Expand Down Expand Up @@ -105,7 +109,13 @@ def log_dict(self, data: Dict, step: int, eval_step: int, event: LogEvent) -> No
data = flatten_dict(data, sep="/")

for key, value in data.items():
self.log_stat(key, value, step, eval_step, event)
self.log_stat(
key,
value,
step,
eval_step,
event,
)

def stop(self) -> None:
"""Stop the logger."""
Expand Down Expand Up @@ -230,8 +240,8 @@ def __init__(self, cfg: DictConfig, unique_token: str) -> None:
tb_exp_path = get_logger_path(cfg, "tensorboard")
tb_logs_path = os.path.join(cfg.logger.base_exp_path, f"{tb_exp_path}/{unique_token}")

configure(tb_logs_path)
self.log = log_value
self.logger = tensorboard_logger.Logger(tb_logs_path)
self.log = self.logger.log_value

def log_stat(self, key: str, value: float, step: int, eval_step: int, event: LogEvent) -> None:
t = step if event != LogEvent.EVAL else eval_step
Expand Down

0 comments on commit 0dcf410

Please sign in to comment.