@@ -20,7 +20,7 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--exp-name', type=str, default=os.path.basename(__file__).rstrip(".py"),
         help='the name of this experiment')
-    parser.add_argument('--gym-id', type=str, default="Pong-v5",
+    parser.add_argument('--gym-id', type=str, default="Breakout-v5",
         help='the id of the gym environment')
     parser.add_argument('--learning-rate', type=float, default=2.5e-4,
         help='the learning rate of the optimizer')
@@ -42,7 +42,7 @@ def parse_args():
         help='whether to capture videos of the agent performances (check out `videos` folder)')
 
     # Algorithm specific arguments
-    parser.add_argument('--num-envs', type=int, default=32,
+    parser.add_argument('--num-envs', type=int, default=8,
         help='the number of parallel game environments')
     parser.add_argument('--num-steps', type=int, default=128,
         help='the number of steps to run in each environment per policy rollout')
@@ -85,12 +85,20 @@ def __init__(self, env, deque_size=100):
         self.num_envs = getattr(env, "num_envs", 1)
         self.episode_returns = None
         self.episode_lengths = None
-        self.is_vector_env = True
+        # get if the env has lives
+        self.has_lives = False
+        env.reset()
+        info = env.step(np.zeros(self.num_envs, dtype=int))[-1]
+        if info["lives"].sum() > 0:
+            self.has_lives = True
+            print("env has lives")
+
 
     def reset(self, **kwargs):
         observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
         self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
         self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
+        self.lives = np.zeros(self.num_envs, dtype=np.int32)
         self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
         self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
         return observations
@@ -103,8 +111,13 @@ def step(self, action):
         self.episode_lengths += 1
         self.returned_episode_returns[:] = self.episode_returns
         self.returned_episode_lengths[:] = self.episode_lengths
-        self.episode_returns *= (1 - dones)
-        self.episode_lengths *= (1 - dones)
+        all_lives_exhausted = infos["lives"] == 0
+        if self.has_lives:
+            self.episode_returns *= (1 - all_lives_exhausted)
+            self.episode_lengths *= (1 - all_lives_exhausted)
+        else:
+            self.episode_returns *= (1 - dones)
+            self.episode_lengths *= (1 - dones)
         infos["r"] = self.returned_episode_returns
         infos["l"] = self.returned_episode_lengths
         return (
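The branch above is the heart of this change: with envpool's `episodic_life=True`, `dones` fires on every life loss, so the wrapper keeps accumulating returns and only zeroes them once `infos["lives"]` reaches 0. A minimal standalone sketch of that masking on toy NumPy arrays (illustrative only, not part of the diff):

# life_masking_sketch.py -- toy values, not part of the diff
import numpy as np

episode_returns = np.array([3.0, 5.0, 7.0], dtype=np.float32)
dones = np.array([1, 1, 0])   # envpool reports "done" at every life loss
lives = np.array([2, 0, 3])   # remaining lives reported by the env

# life-aware masking: only reset accumulators where the game is truly over
all_lives_exhausted = lives == 0
episode_returns *= (1 - all_lives_exhausted)
print(episode_returns)        # [3. 0. 7.] -- env 0 lost a life but keeps its running return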
@@ -180,9 +193,14 @@ def get_action_and_value(self, x, action=None):
     device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
 
     # env setup
-    envs = envpool.make(args.gym_id, env_type="gym", num_envs=args.num_envs)
+    envs = envpool.make(
+        args.gym_id,
+        env_type="gym",
+        num_envs=args.num_envs,
+        episodic_life=True,
+        reward_clip=True,
+    )
     envs.num_envs = args.num_envs
-    # envs.is_vector_env = True
     envs = RecordEpisodeStatistics(envs)
     assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported"
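Putting the setup and the wrapper together, the intended usage looks roughly like the sketch below. It reuses only calls already visible in this diff (envpool.make with these kwargs, reset()/step(), and the wrapper's info["r"] and info["lives"] fields); the fixed env id, env count, and the NOOP action policy are illustrative placeholders, not part of the commit.

# usage sketch -- assumes the RecordEpisodeStatistics wrapper defined above
import envpool
import numpy as np

num_envs = 8
envs = envpool.make(
    "Breakout-v5",
    env_type="gym",
    num_envs=num_envs,
    episodic_life=True,   # done is emitted on every life loss
    reward_clip=True,
)
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)

obs = envs.reset()
for _ in range(100):
    actions = np.zeros(num_envs, dtype=int)        # placeholder policy (NOOP)
    obs, rewards, dones, info = envs.step(actions)
    for idx, d in enumerate(dones):
        if d and info["lives"][idx] == 0:          # true game over, not a single life loss
            print("episodic return:", info["r"][idx])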
@@ -196,6 +214,7 @@ def get_action_and_value(self, x, action=None):
     rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
     dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
     values = torch.zeros((args.num_steps, args.num_envs)).to(device)
+    avg_returns = deque(maxlen=10)
 
     # TRY NOT TO MODIFY: start the game
     global_step = 0
@@ -229,10 +248,15 @@ def get_action_and_value(self, x, action=None):
             next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
 
             for idx, d in enumerate(done):
-                if d:
+                if d and info['lives'][idx] == 0:
                     print(f"global_step={global_step}, episodic_return={info['r'][idx]}")
+                    avg_returns.append(info['r'][idx])
+                    writer.add_scalar("charts/avg_episodic_return", np.average(avg_returns), global_step)
                     writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
                     writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
+                    # if np.average(avg_returns) > 17:
+                    #     writer.add_scalar("charts/time", time.time() - start_time, global_step)
+                    #     quit()
 
         # bootstrap value if not done
         with torch.no_grad():