Commit 9ff053b

[Feature] PPOTrainer

ghstack-source-id: ddd00c7ffb309d9fb845cdf8392c46774cb12b01
Pull Request resolved: #2550

1 parent: 19dbeeb

10 files changed: +636 -459 lines

sota-implementations/ppo/config_atari.yaml (-39 lines; regular file replaced by a symlink)

@@ -0,0 +1 @@
+../../torchrl/trainers/agents/config_atari.yaml

sota-implementations/ppo/config_mujoco.yaml (-36 lines; regular file replaced by a symlink)

@@ -0,0 +1 @@
+../../torchrl/trainers/agents/config_mujoco.yaml

sota-implementations/ppo/ppo_atari.py (+9 -5)

@@ -10,6 +10,7 @@
 import hydra
 from torchrl._utils import logger as torchrl_logger
 from torchrl.record import VideoRecorder
+from torchrl.trainers.agents.ppo import AtariPPOTrainer


 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
@@ -28,7 +29,6 @@ def main(cfg: "DictConfig"): # noqa: F821
     from torchrl.objectives import ClipPPOLoss
     from torchrl.objectives.value.advantages import GAE
     from torchrl.record.loggers import generate_exp_name, get_logger
-    from utils_atari import eval_model, make_parallel_env, make_ppo_models

     device = "cpu" if not torch.cuda.device_count() else "cuda"

@@ -40,12 +40,14 @@ def main(cfg: "DictConfig"): # noqa: F821
     test_interval = cfg.logger.test_interval // frame_skip

     # Create models (check utils_atari.py)
-    actor, critic = make_ppo_models(cfg.env.env_name)
+    actor, critic = AtariPPOTrainer.make_ppo_models(cfg.env.env_name)
     actor, critic = actor.to(device), critic.to(device)

     # Create collector
     collector = SyncDataCollector(
-        create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, "cpu"),
+        create_env_fn=AtariPPOTrainer.make_parallel_env(
+            cfg.env.env_name, cfg.env.num_envs, "cpu"
+        ),
         policy=actor,
         frames_per_batch=frames_per_batch,
         total_frames=total_frames,
@@ -110,7 +112,9 @@ def main(cfg: "DictConfig"): # noqa: F821
         logger_video = False

     # Create test environment
-    test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True)
+    test_env = AtariPPOTrainer.make_parallel_env(
+        cfg.env.env_name, 1, device, is_test=True
+    )
     if logger_video:
         test_env = test_env.append_transform(
             VideoRecorder(logger, tag="rendering/test", in_keys=["pixels_int"])
@@ -223,7 +227,7 @@ def main(cfg: "DictConfig"): # noqa: F821
             ) // test_interval:
                 actor.eval()
                 eval_start = time.time()
-                test_rewards = eval_model(
+                test_rewards = AtariPPOTrainer.eval_model(
                     actor, test_env, num_episodes=cfg_logger_num_test_episodes
                 )
                 eval_time = time.time() - eval_start
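Taken together, these hunks replace the script-local utils_atari helpers with classmethods on the new AtariPPOTrainer, so the example script no longer depends on a sibling utility module. Below is a minimal sketch of the resulting call pattern; it assumes only the three signatures visible in this diff (make_ppo_models, make_parallel_env, eval_model), and the environment name, worker count, and frame budgets are hypothetical placeholders rather than values from the commit.

# Sketch of the AtariPPOTrainer helpers as exercised by this diff.
# Assumption: only the three calls shown in the hunks, with these
# signatures. "PongNoFrameskip-v4", 8 workers, and the frame counts
# are placeholders, not values taken from the commit.
import torch
from torchrl.collectors import SyncDataCollector
from torchrl.trainers.agents.ppo import AtariPPOTrainer

device = "cpu" if not torch.cuda.device_count() else "cuda"

# Build the actor/critic pair for an Atari task.
actor, critic = AtariPPOTrainer.make_ppo_models("PongNoFrameskip-v4")
actor, critic = actor.to(device), critic.to(device)

# Collect on CPU workers, exactly as the updated script does.
collector = SyncDataCollector(
    create_env_fn=AtariPPOTrainer.make_parallel_env("PongNoFrameskip-v4", 8, "cpu"),
    policy=actor,
    frames_per_batch=4096,
    total_frames=40_000_000,
)

# Evaluation runs on a single test env via the trainer's eval helper.
test_env = AtariPPOTrainer.make_parallel_env(
    "PongNoFrameskip-v4", 1, device, is_test=True
)
test_rewards = AtariPPOTrainer.eval_model(actor, test_env, num_episodes=3)

Moving the builders onto the trainer class appears intended to give the sota-implementations scripts and the library one shared, importable set of PPO constructors instead of per-script utility modules.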

sota-implementations/ppo/ppo_mujoco.py (+7 -5)

@@ -10,6 +10,7 @@
 import hydra
 from torchrl._utils import logger as torchrl_logger
 from torchrl.record import VideoRecorder
+from torchrl.trainers.agents.ppo import ContinuousControlPPOTrainer


 @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
@@ -28,7 +29,6 @@ def main(cfg: "DictConfig"): # noqa: F821
     from torchrl.objectives import ClipPPOLoss
     from torchrl.objectives.value.advantages import GAE
     from torchrl.record.loggers import generate_exp_name, get_logger
-    from utils_mujoco import eval_model, make_env, make_ppo_models

     device = "cpu" if not torch.cuda.device_count() else "cuda"
     num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size
@@ -39,12 +39,12 @@ def main(cfg: "DictConfig"): # noqa: F821
     )

     # Create models (check utils_mujoco.py)
-    actor, critic = make_ppo_models(cfg.env.env_name)
+    actor, critic = ContinuousControlPPOTrainer.make_ppo_models(cfg.env.env_name)
     actor, critic = actor.to(device), critic.to(device)

     # Create collector
     collector = SyncDataCollector(
-        create_env_fn=make_env(cfg.env.env_name, device),
+        create_env_fn=ContinuousControlPPOTrainer.make_env(cfg.env.env_name, device),
         policy=actor,
         frames_per_batch=cfg.collector.frames_per_batch,
         total_frames=cfg.collector.total_frames,
@@ -102,7 +102,9 @@ def main(cfg: "DictConfig"): # noqa: F821
         logger_video = False

     # Create test environment
-    test_env = make_env(cfg.env.env_name, device, from_pixels=logger_video)
+    test_env = ContinuousControlPPOTrainer.make_env(
+        cfg.env.env_name, device, from_pixels=logger_video
+    )
     if logger_video:
         test_env = test_env.append_transform(
             VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
@@ -216,7 +218,7 @@ def main(cfg: "DictConfig"): # noqa: F821
             ) // cfg_logger_test_interval:
                 actor.eval()
                 eval_start = time.time()
-                test_rewards = eval_model(
+                test_rewards = ContinuousControlPPOTrainer.eval_model(
                     actor, test_env, num_episodes=cfg_logger_num_test_episodes
                 )
                 eval_time = time.time() - eval_start
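The MuJoCo script gets the same treatment via ContinuousControlPPOTrainer; the visible difference is a single-env make_env helper (with a from_pixels flag for video logging) in place of make_parallel_env. A minimal sketch under the same caveats, with a placeholder environment name and batch sizes:

# Sketch of the ContinuousControlPPOTrainer helpers as exercised by
# this diff. "HalfCheetah-v4" and the frame counts are placeholders.
import torch
from torchrl.collectors import SyncDataCollector
from torchrl.trainers.agents.ppo import ContinuousControlPPOTrainer

device = "cpu" if not torch.cuda.device_count() else "cuda"

# Actor/critic construction mirrors the Atari variant.
actor, critic = ContinuousControlPPOTrainer.make_ppo_models("HalfCheetah-v4")
actor, critic = actor.to(device), critic.to(device)

# MuJoCo collects from a single environment factory.
collector = SyncDataCollector(
    create_env_fn=ContinuousControlPPOTrainer.make_env("HalfCheetah-v4", device),
    policy=actor,
    frames_per_batch=2048,
    total_frames=1_000_000,
)

# from_pixels only matters when a video logger is attached.
test_env = ContinuousControlPPOTrainer.make_env(
    "HalfCheetah-v4", device, from_pixels=False
)
test_rewards = ContinuousControlPPOTrainer.eval_model(
    actor, test_env, num_episodes=3
)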
