Skip to content

Remove SB3 as a necessary module #505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ jobs:
- name: Install jax
if: runner.os == 'Linux' || runner.os == 'macOS'
run: poetry install -E "pytest jax"
- name: Run gymnasium migration dependencies
run: poetry run pip install "stable_baselines3==2.0.0a1"
- name: Run gymnasium tests
run: poetry run pytest tests/test_classic_control_gymnasium.py
- name: Run core tests with jax
Expand Down Expand Up @@ -78,7 +76,7 @@ jobs:
if: runner.os == 'Linux' || runner.os == 'macOS'
run: poetry install -E "pytest atari jax"
- name: Run gymnasium migration dependencies
run: poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1"
run: poetry run pip install "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1"
- name: Run gymnasium tests
run: poetry run pytest tests/test_atari_gymnasium.py
- name: Run gymnasium tests with jax
Expand Down Expand Up @@ -134,8 +132,6 @@ jobs:
run: poetry install -E "pytest mujoco dm_control jax"
- name: Downgrade setuptools
run: poetry run pip install setuptools==59.5.0
- name: Run gymnasium migration dependencies
run: poetry run pip install "stable_baselines3==2.0.0a1"
- name: install mujoco dependencies
run: |
sudo apt-get update && sudo apt-get -y install libgl1-mesa-glx libosmesa6 libglfw3
Expand Down Expand Up @@ -166,8 +162,6 @@ jobs:
# run: poetry install -E "pytest mujoco dm_control jax"
# - name: Downgrade setuptools
# run: poetry run pip install setuptools==59.5.0
# - name: Run gymnasium migration dependencies
# run: poetry run pip install "stable_baselines3==2.0.0a1"
# - name: Run mujoco tests
# run: poetry run pytest tests/test_mujoco.py

Expand Down
12 changes: 2 additions & 10 deletions cleanrl/c51.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
import torch.nn as nn
import torch.optim as optim
import tyro
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter

from cleanrl_utils.buffers import ReplayBuffer


@dataclass
class Args:
Expand Down Expand Up @@ -121,15 +122,6 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):


if __name__ == "__main__":
import stable_baselines3 as sb3

if sb3.__version__ < "2.0":
raise ValueError(
"""Ongoing migration: run the following command to install the new dependencies:

poetry run pip install "stable_baselines3==2.0.0a1"
"""
)
args = tyro.cli(Args)
assert args.num_envs == 1, "vectorized envs are not supported at the moment"
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
Expand Down
16 changes: 4 additions & 12 deletions cleanrl/c51_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@
import torch.nn as nn
import torch.optim as optim
import tyro
from stable_baselines3.common.atari_wrappers import (
from torch.utils.tensorboard import SummaryWriter

from cleanrl_utils.atari_wrappers import (
ClipRewardEnv,
EpisodicLifeEnv,
FireResetEnv,
MaxAndSkipEnv,
NoopResetEnv,
)
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
from cleanrl_utils.buffers import ReplayBuffer


@dataclass
Expand Down Expand Up @@ -143,15 +144,6 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):


if __name__ == "__main__":
import stable_baselines3 as sb3

if sb3.__version__ < "2.0":
raise ValueError(
"""Ongoing migration: run the following command to install the new dependencies:

poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1"
"""
)
args = tyro.cli(Args)
assert args.num_envs == 1, "vectorized envs are not supported at the moment"
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
Expand Down
15 changes: 4 additions & 11 deletions cleanrl/c51_atari_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@
import optax
import tyro
from flax.training.train_state import TrainState
from stable_baselines3.common.atari_wrappers import (
from torch.utils.tensorboard import SummaryWriter

from cleanrl_utils.atari_wrappers import (
ClipRewardEnv,
EpisodicLifeEnv,
FireResetEnv,
MaxAndSkipEnv,
NoopResetEnv,
)
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
from cleanrl_utils.buffers import ReplayBuffer


@dataclass
Expand Down Expand Up @@ -144,15 +145,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):


if __name__ == "__main__":
import stable_baselines3 as sb3

if sb3.__version__ < "2.0":
raise ValueError(
"""Ongoing migration: run the following command to install the new dependencies:

poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1"
"""
)
args = tyro.cli(Args)
assert args.num_envs == 1, "vectorized envs are not supported at the moment"
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
Expand Down
11 changes: 2 additions & 9 deletions cleanrl/c51_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
import optax
import tyro
from flax.training.train_state import TrainState
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter

from cleanrl_utils.buffers import ReplayBuffer


@dataclass
class Args:
Expand Down Expand Up @@ -116,15 +117,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):


if __name__ == "__main__":
import stable_baselines3 as sb3

if sb3.__version__ < "2.0":
raise ValueError(
"""Ongoing migration: run the following command to install the new dependencies:

poetry run pip install "stable_baselines3==2.0.0a1"
"""
)
args = tyro.cli(Args)
assert args.num_envs == 1, "vectorized envs are not supported at the moment"
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
Expand Down
10 changes: 2 additions & 8 deletions cleanrl/ddpg_continuous_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
import torch.nn.functional as F
import torch.optim as optim
import tyro
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter

from cleanrl_utils.buffers import ReplayBuffer


@dataclass
class Args:
Expand Down Expand Up @@ -117,14 +118,7 @@ def forward(self, x):


if __name__ == "__main__":
import stable_baselines3 as sb3

if sb3.__version__ < "2.0":
raise ValueError(
"""Ongoing migration: run the following command to install the new dependencies:
poetry run pip install "stable_baselines3==2.0.0a1"
"""
)
args = tyro.cli(Args)
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
Expand Down
10 changes: 2 additions & 8 deletions cleanrl/ddpg_continuous_action_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
import optax
import tyro
from flax.training.train_state import TrainState
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter

from cleanrl_utils.buffers import ReplayBuffer


@dataclass
class Args:
Expand Down Expand Up @@ -112,14 +113,7 @@ class TrainState(TrainState):


if __name__ == "__main__":
import stable_baselines3 as sb3

if sb3.__version__ < "2.0":
raise ValueError(
"""Ongoing migration: run the following command to install the new dependencies:
poetry run pip install "stable_baselines3==2.0.0a1"
"""
)
args = tyro.cli(Args)
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
Expand Down
11 changes: 2 additions & 9 deletions cleanrl/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
import torch.nn.functional as F
import torch.optim as optim
import tyro
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter

from cleanrl_utils.buffers import ReplayBuffer


@dataclass
class Args:
Expand Down Expand Up @@ -108,15 +109,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):


if __name__ == "__main__":
import stable_baselines3 as sb3

if sb3.__version__ < "2.0":
raise ValueError(
"""Ongoing migration: run the following command to install the new dependencies:

poetry run pip install "stable_baselines3==2.0.0a1"
"""
)
args = tyro.cli(Args)
assert args.num_envs == 1, "vectorized envs are not supported at the moment"
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
Expand Down
15 changes: 4 additions & 11 deletions cleanrl/dqn_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@
import torch.nn.functional as F
import torch.optim as optim
import tyro
from stable_baselines3.common.atari_wrappers import (
from torch.utils.tensorboard import SummaryWriter

from cleanrl_utils.atari_wrappers import (
ClipRewardEnv,
EpisodicLifeEnv,
FireResetEnv,
MaxAndSkipEnv,
NoopResetEnv,
)
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
from cleanrl_utils.buffers import ReplayBuffer


@dataclass
Expand Down Expand Up @@ -130,15 +131,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):


if __name__ == "__main__":
import stable_baselines3 as sb3

if sb3.__version__ < "2.0":
raise ValueError(
"""Ongoing migration: run the following command to install the new dependencies:

poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1"
"""
)
args = tyro.cli(Args)
assert args.num_envs == 1, "vectorized envs are not supported at the moment"
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
Expand Down
15 changes: 4 additions & 11 deletions cleanrl/dqn_atari_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@
import optax
import tyro
from flax.training.train_state import TrainState
from stable_baselines3.common.atari_wrappers import (
from torch.utils.tensorboard import SummaryWriter

from cleanrl_utils.atari_wrappers import (
ClipRewardEnv,
EpisodicLifeEnv,
FireResetEnv,
MaxAndSkipEnv,
NoopResetEnv,
)
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
from cleanrl_utils.buffers import ReplayBuffer


@dataclass
Expand Down Expand Up @@ -136,15 +137,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):


if __name__ == "__main__":
import stable_baselines3 as sb3

if sb3.__version__ < "2.0":
raise ValueError(
"""Ongoing migration: run the following command to install the new dependencies:

poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1"
"""
)
args = tyro.cli(Args)
assert args.num_envs == 1, "vectorized envs are not supported at the moment"
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
Expand Down
12 changes: 2 additions & 10 deletions cleanrl/dqn_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
import optax
import tyro
from flax.training.train_state import TrainState
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter

from cleanrl_utils.buffers import ReplayBuffer


@dataclass
class Args:
Expand Down Expand Up @@ -108,15 +109,6 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):


if __name__ == "__main__":
import stable_baselines3 as sb3

if sb3.__version__ < "2.0":
raise ValueError(
"""Ongoing migration: run the following command to install the new dependencies:

poetry run pip install "stable_baselines3==2.0.0a1"
"""
)
args = tyro.cli(Args)
assert args.num_envs == 1, "vectorized envs are not supported at the moment"
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
Expand Down
2 changes: 1 addition & 1 deletion cleanrl/ppo_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from stable_baselines3.common.atari_wrappers import ( # isort:skip
from cleanrl_utils.atari_wrappers import ( # isort:skip
ClipRewardEnv,
EpisodicLifeEnv,
FireResetEnv,
Expand Down
2 changes: 1 addition & 1 deletion cleanrl/ppo_atari_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from stable_baselines3.common.atari_wrappers import ( # isort:skip
from cleanrl_utils.atari_wrappers import ( # isort:skip
ClipRewardEnv,
EpisodicLifeEnv,
FireResetEnv,
Expand Down
2 changes: 1 addition & 1 deletion cleanrl/ppo_atari_multigpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from stable_baselines3.common.atari_wrappers import ( # isort:skip
from cleanrl_utils.atari_wrappers import ( # isort:skip
ClipRewardEnv,
EpisodicLifeEnv,
FireResetEnv,
Expand Down
Loading
Loading