Commit 2613097
Commit message: pushing changes
1 parent: b44a0fd

1 file changed: +32 -8 lines changed

cleanrl/ppo_atari_envpool.py (+32 -8)
@@ -20,7 +20,7 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--exp-name', type=str, default=os.path.basename(__file__).rstrip(".py"),
         help='the name of this experiment')
-    parser.add_argument('--gym-id', type=str, default="Pong-v5",
+    parser.add_argument('--gym-id', type=str, default="Breakout-v5",
         help='the id of the gym environment')
     parser.add_argument('--learning-rate', type=float, default=2.5e-4,
         help='the learning rate of the optimizer')
@@ -42,7 +42,7 @@ def parse_args():
         help='weather to capture videos of the agent performances (check out `videos` folder)')

     # Algorithm specific arguments
-    parser.add_argument('--num-envs', type=int, default=32,
+    parser.add_argument('--num-envs', type=int, default=8,
         help='the number of parallel game environments')
     parser.add_argument('--num-steps', type=int, default=128,
         help='the number of steps to run in each environment per policy rollout')
@@ -85,12 +85,20 @@ def __init__(self, env, deque_size=100):
         self.num_envs = getattr(env, "num_envs", 1)
         self.episode_returns = None
         self.episode_lengths = None
-        self.is_vector_env = True
+        # get if the env has lives
+        self.has_lives = False
+        env.reset()
+        info = env.step(np.zeros(self.num_envs, dtype=int))[-1]
+        if info["lives"].sum() > 0:
+            self.has_lives = True
+            print("env has lives")
+

     def reset(self, **kwargs):
         observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
         self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
         self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
+        self.lives = np.zeros(self.num_envs, dtype=np.int32)
         self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
         self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
         return observations
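
The `__init__` change above probes the vectorized env once to decide whether the game exposes a life counter. A minimal standalone sketch of that probe, assuming (as the diff does) that envpool's gym-style Atari API returns an info dict whose "lives" entry is a per-env integer array; the task id and num_envs here are illustrative:

import numpy as np
import envpool

envs = envpool.make("Breakout-v5", env_type="gym", num_envs=8)
envs.reset()
# Take one batch of no-op actions; the last element returned by step() is the info dict.
info = envs.step(np.zeros(8, dtype=int))[-1]
has_lives = info["lives"].sum() > 0  # Breakout reports remaining lives, Pong reports zeros
print("env has lives" if has_lives else "env has no lives")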
@@ -103,8 +111,13 @@ def step(self, action):
         self.episode_lengths += 1
         self.returned_episode_returns[:] = self.episode_returns
         self.returned_episode_lengths[:] = self.episode_lengths
-        self.episode_returns *= (1 - dones)
-        self.episode_lengths *= (1 - dones)
+        all_lives_exhausted = infos["lives"] == 0
+        if self.has_lives:
+            self.episode_returns *= (1 - all_lives_exhausted)
+            self.episode_lengths *= (1 - all_lives_exhausted)
+        else:
+            self.episode_returns *= (1 - dones)
+            self.episode_lengths *= (1 - dones)
         infos["r"] = self.returned_episode_returns
         infos["l"] = self.returned_episode_lengths
         return (
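
A toy illustration (not part of the commit, made-up values) of the masking used in `step()` above: multiplying by `(1 - mask)` zeroes the running statistics only for the envs whose condition is true, so when `has_lives` is set the returns keep accumulating across lost lives and reset only at real game over:

import numpy as np

episode_returns = np.array([3.0, 7.0, 1.0], dtype=np.float32)
lives = np.array([2, 0, 1])            # per-env "lives" reported in infos
all_lives_exhausted = lives == 0       # True only where the full game ended
episode_returns *= (1 - all_lives_exhausted)
print(episode_returns)                 # [3. 0. 1.] -- only env 1 is reset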
@@ -180,9 +193,14 @@ def get_action_and_value(self, x, action=None):
     device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

     # env setup
-    envs = envpool.make(args.gym_id, env_type="gym", num_envs=args.num_envs)
+    envs = envpool.make(
+        args.gym_id,
+        env_type="gym",
+        num_envs=args.num_envs,
+        episodic_life=True,
+        reward_clip=True,
+    )
     envs.num_envs = args.num_envs
-    # envs.is_vector_env = True
     envs = RecordEpisodeStatistics(envs)
     assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported"
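
For context, a minimal sketch (not part of the file) of building the vectorized env with the keyword arguments the diff adds. With `episodic_life=True` envpool signals `done` whenever a life is lost, which is why the wrapper above tracks the real game over via `lives == 0`; the task id and env count are illustrative, the script itself uses `args.gym_id` and `args.num_envs`:

import envpool

envs = envpool.make(
    "Breakout-v5",
    env_type="gym",
    num_envs=8,
    episodic_life=True,   # done is reported on every lost life
    reward_clip=True,     # clip rewards for training stability
)
print(envs.observation_space)  # stacked 84x84 grayscale frames by default
print(envs.action_space)       # a gym.spaces.Discrete action space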

@@ -196,6 +214,7 @@ def get_action_and_value(self, x, action=None):
     rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
     dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
     values = torch.zeros((args.num_steps, args.num_envs)).to(device)
+    avg_returns = deque(maxlen=10)

     # TRY NOT TO MODIFY: start the game
     global_step = 0
@@ -229,10 +248,15 @@ def get_action_and_value(self, x, action=None):
             next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)

             for idx, d in enumerate(done):
-                if d:
+                if d and info['lives'][idx] == 0:
                     print(f"global_step={global_step}, episodic_return={info['r'][idx]}")
+                    avg_returns.append(info['r'][idx])
+                    writer.add_scalar("charts/avg_episodic_return", np.average(avg_returns), global_step)
                     writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
                     writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
+                    # if np.average(avg_returns) > 17:
+                    #     writer.add_scalar("charts/time", time.time() - start_time, global_step)
+                    #     quit()

         # bootstrap value if not done
         with torch.no_grad():
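
The logging change keeps a rolling window of recent episodic returns. A small illustration with assumed values of how `deque(maxlen=10)` plus `np.average` yields a 10-episode moving average:

from collections import deque
import numpy as np

avg_returns = deque(maxlen=10)
for ret in range(1, 13):        # pretend 12 episodes finished with returns 1..12
    avg_returns.append(ret)
print(list(avg_returns))        # [3, 4, ..., 12] -- only the last 10 are kept
print(np.average(avg_returns))  # 7.5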
