Skip to content

Commit 7865501

Browse files
Remove additional sb3 checks
1 parent d0d0cff commit 7865501

9 files changed

+1
-63
lines changed

cleanrl/c51_jax.py

-8
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
117117

118118

119119
if __name__ == "__main__":
120-
import stable_baselines3 as sb3
121120

122-
if sb3.__version__ < "2.0":
123-
raise ValueError(
124-
"""Ongoing migration: run the following command to install the new dependencies:
125-
126-
poetry run pip install "stable_baselines3==2.0.0a1"
127-
"""
128-
)
129121
args = tyro.cli(Args)
130122
assert args.num_envs == 1, "vectorized envs are not supported at the moment"
131123
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"

cleanrl/ddpg_continuous_action.py

-7
Original file line numberDiff line numberDiff line change
@@ -118,14 +118,7 @@ def forward(self, x):
118118

119119

120120
if __name__ == "__main__":
121-
import stable_baselines3 as sb3
122121

123-
if sb3.__version__ < "2.0":
124-
raise ValueError(
125-
"""Ongoing migration: run the following command to install the new dependencies:
126-
poetry run pip install "stable_baselines3==2.0.0a1"
127-
"""
128-
)
129122
args = tyro.cli(Args)
130123
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
131124
if args.track:

cleanrl/ddpg_continuous_action_jax.py

-7
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,7 @@ class TrainState(TrainState):
113113

114114

115115
if __name__ == "__main__":
116-
import stable_baselines3 as sb3
117116

118-
if sb3.__version__ < "2.0":
119-
raise ValueError(
120-
"""Ongoing migration: run the following command to install the new dependencies:
121-
poetry run pip install "stable_baselines3==2.0.0a1"
122-
"""
123-
)
124117
args = tyro.cli(Args)
125118
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
126119
if args.track:

cleanrl/dqn.py

-8
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
109109

110110

111111
if __name__ == "__main__":
112-
import stable_baselines3 as sb3
113112

114-
if sb3.__version__ < "2.0":
115-
raise ValueError(
116-
"""Ongoing migration: run the following command to install the new dependencies:
117-
118-
poetry run pip install "stable_baselines3==2.0.0a1"
119-
"""
120-
)
121113
args = tyro.cli(Args)
122114
assert args.num_envs == 1, "vectorized envs are not supported at the moment"
123115
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"

cleanrl/dqn_jax.py

-9
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,6 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
109109

110110

111111
if __name__ == "__main__":
112-
import stable_baselines3 as sb3
113-
114-
if sb3.__version__ < "2.0":
115-
raise ValueError(
116-
"""Ongoing migration: run the following command to install the new dependencies:
117-
118-
poetry run pip install "stable_baselines3==2.0.0a1"
119-
"""
120-
)
121112
args = tyro.cli(Args)
122113
assert args.num_envs == 1, "vectorized envs are not supported at the moment"
123114
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"

cleanrl/sac_continuous_action.py

-8
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,6 @@ def get_action(self, x):
152152

153153

154154
if __name__ == "__main__":
155-
import stable_baselines3 as sb3
156-
157-
if sb3.__version__ < "2.0":
158-
raise ValueError(
159-
"""Ongoing migration: run the following command to install the new dependencies:
160-
poetry run pip install "stable_baselines3==2.0.0a1"
161-
"""
162-
)
163155

164156
args = tyro.cli(Args)
165157
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"

cleanrl/td3_continuous_action.py

-8
Original file line numberDiff line numberDiff line change
@@ -133,14 +133,6 @@ def forward(self, x):
133133

134134

135135
if __name__ == "__main__":
136-
import stable_baselines3 as sb3
137-
138-
if sb3.__version__ < "2.0":
139-
raise ValueError(
140-
"""Ongoing migration: run the following command to install the new dependencies:
141-
poetry run pip install "stable_baselines3==2.0.0a1"
142-
"""
143-
)
144136

145137
args = tyro.cli(Args)
146138
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"

cleanrl/td3_continuous_action_jax.py

-7
Original file line numberDiff line numberDiff line change
@@ -115,14 +115,7 @@ class TrainState(TrainState):
115115

116116

117117
if __name__ == "__main__":
118-
import stable_baselines3 as sb3
119118

120-
if sb3.__version__ < "2.0":
121-
raise ValueError(
122-
"""Ongoing migration: run the following command to install the new dependencies:
123-
poetry run pip install "stable_baselines3==2.0.0a1"
124-
"""
125-
)
126119
args = tyro.cli(Args)
127120
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
128121
if args.track:

docs/get-started/basic-usage.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ Currently, `ddpg_continuous_action_jax.py`, `ddpg_continuous_action.py` have bee
5050
Please note that, `stable-baselines3` version `1.2` does not support `gymnasium`. To use these scripts, please install the `alpha1` version like,
5151
5252
```
53-
poetry run pip install sb3==2.0.0a1
53+
poetry run pip install ==2.0.0a1
5454
```
5555
5656
!!! warning

0 commit comments

Comments
 (0)