10
10
import hydra
11
11
from torchrl ._utils import logger as torchrl_logger
12
12
from torchrl .record import VideoRecorder
13
+ from torchrl .trainers .agents .ppo import AtariPPOTrainer
13
14
14
15
15
16
@hydra .main (config_path = "" , config_name = "config_atari" , version_base = "1.1" )
@@ -28,7 +29,6 @@ def main(cfg: "DictConfig"): # noqa: F821
28
29
from torchrl .objectives import ClipPPOLoss
29
30
from torchrl .objectives .value .advantages import GAE
30
31
from torchrl .record .loggers import generate_exp_name , get_logger
31
- from utils_atari import eval_model , make_parallel_env , make_ppo_models
32
32
33
33
device = "cpu" if not torch .cuda .device_count () else "cuda"
34
34
@@ -40,12 +40,14 @@ def main(cfg: "DictConfig"): # noqa: F821
40
40
test_interval = cfg .logger .test_interval // frame_skip
41
41
42
42
# Create models (see AtariPPOTrainer.make_ppo_models)
43
- actor , critic = make_ppo_models (cfg .env .env_name )
43
+ actor , critic = AtariPPOTrainer . make_ppo_models (cfg .env .env_name )
44
44
actor , critic = actor .to (device ), critic .to (device )
45
45
46
46
# Create collector
47
47
collector = SyncDataCollector (
48
- create_env_fn = make_parallel_env (cfg .env .env_name , cfg .env .num_envs , "cpu" ),
48
+ create_env_fn = AtariPPOTrainer .make_parallel_env (
49
+ cfg .env .env_name , cfg .env .num_envs , "cpu"
50
+ ),
49
51
policy = actor ,
50
52
frames_per_batch = frames_per_batch ,
51
53
total_frames = total_frames ,
@@ -110,7 +112,9 @@ def main(cfg: "DictConfig"): # noqa: F821
110
112
logger_video = False
111
113
112
114
# Create test environment
113
- test_env = make_parallel_env (cfg .env .env_name , 1 , device , is_test = True )
115
+ test_env = AtariPPOTrainer .make_parallel_env (
116
+ cfg .env .env_name , 1 , device , is_test = True
117
+ )
114
118
if logger_video :
115
119
test_env = test_env .append_transform (
116
120
VideoRecorder (logger , tag = "rendering/test" , in_keys = ["pixels_int" ])
@@ -223,7 +227,7 @@ def main(cfg: "DictConfig"): # noqa: F821
223
227
) // test_interval :
224
228
actor .eval ()
225
229
eval_start = time .time ()
226
- test_rewards = eval_model (
230
+ test_rewards = AtariPPOTrainer . eval_model (
227
231
actor , test_env , num_episodes = cfg_logger_num_test_episodes
228
232
)
229
233
eval_time = time .time () - eval_start
0 commit comments