From f127aa3d5958917d44184b2cda4583be021d84b6 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sun, 29 May 2022 13:40:57 -0400 Subject: [PATCH 01/27] prototype jax with ddpg --- cleanrl/dd.py | 23 +++ cleanrl/ddpg_continuous_action_jax.py | 287 ++++++++++++++++++++++++++ cleanrl/ddpg_continuous_action_jit.py | 262 +++++++++++++++++++++++ poetry.lock | 207 ++++++++++++++++++- pyproject.toml | 4 + 5 files changed, 782 insertions(+), 1 deletion(-) create mode 100644 cleanrl/dd.py create mode 100644 cleanrl/ddpg_continuous_action_jax.py create mode 100644 cleanrl/ddpg_continuous_action_jit.py diff --git a/cleanrl/dd.py b/cleanrl/dd.py new file mode 100644 index 000000000..67c313cf7 --- /dev/null +++ b/cleanrl/dd.py @@ -0,0 +1,23 @@ +from typing import Sequence + +import numpy as np +import jax +import jax.numpy as jnp +import flax.linen as nn + +class MLP(nn.Module): + features: Sequence[int] + + @nn.compact + def __call__(self, x): + for feat in self.features[:-1]: + x = nn.relu(nn.Dense(feat)(x)) + x = nn.Dense(self.features[-1])(x) + return x + +model = MLP([12, 8, 4]) +batch = jnp.ones((32, 10)) +variables = model.init(jax.random.PRNGKey(0), batch) +# print(variables) +# for _ in range(4000): +# output = diff --git a/cleanrl/ddpg_continuous_action_jax.py b/cleanrl/ddpg_continuous_action_jax.py new file mode 100644 index 000000000..04ca399c6 --- /dev/null +++ b/cleanrl/ddpg_continuous_action_jax.py @@ -0,0 +1,287 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ddpg/#ddpg_continuous_actionpy +# docs and experiment results can be found at +# https://docs.cleanrl.dev/rl-algorithms/ddpg/#ddpg_continuous_actionpy + +import argparse +import os +import random +import time +from distutils.util import strtobool +from typing import Sequence + +import gym +import numpy as np +import pybullet_envs # noqa +import torch +import jax +import jax.numpy as jnp +import flax.linen as nn +import optax +import torch.nn.functional as F +import torch.optim as optim +from stable_baselines3.common.buffers import ReplayBuffer +from torch.utils.tensorboard import SummaryWriter + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, `torch.backends.cudnn.deterministic=False`") + parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, cuda will be enabled by default") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="cleanRL", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="weather to capture videos of the agent performances (check out `videos` folder)") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="HalfCheetah-v2", + help="the id of the environment") + parser.add_argument("--total-timesteps", 
type=int, default=1000000, + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=3e-4, + help="the learning rate of the optimizer") + parser.add_argument("--buffer-size", type=int, default=int(1e6), + help="the replay memory buffer size") + parser.add_argument("--gamma", type=float, default=0.99, + help="the discount factor gamma") + parser.add_argument("--tau", type=float, default=0.005, + help="target smoothing coefficient (default: 0.005)") + parser.add_argument("--batch-size", type=int, default=256, + help="the batch size of sample from the reply memory") + parser.add_argument("--exploration-noise", type=float, default=0.1, + help="the scale of exploration noise") + parser.add_argument("--learning-starts", type=int, default=25e3, + help="timestep to start learning") + parser.add_argument("--policy-frequency", type=int, default=2, + help="the frequency of training policy (delayed)") + parser.add_argument("--noise-clip", type=float, default=0.5, + help="noise clip parameter of the Target Policy Smoothing Regularization") + args = parser.parse_args() + # fmt: on + return args + + +def make_env(env_id, seed, idx, capture_video, run_name): + def thunk(): + env = gym.make(env_id) + # env = gym.wrappers.RecordEpisodeStatistics(env) + if capture_video: + if idx == 0: + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + return env + + return thunk + + +# ALGO LOGIC: initialize agent here: +class QNetwork(nn.Module): + @nn.compact + def __call__(self, x): + x = nn.Dense(256)(x) + x = nn.relu(x) + x = nn.Dense(256)(x) + x = nn.relu(x) + x = nn.Dense(256)(x) + x = nn.relu(x) + return x + +class Actor(nn.Module): + # state_dim = None + action_dim: Sequence[int] + @nn.compact + def __call__(self, x): + x = nn.Dense(256)(x) + x = nn.relu(x) + x = nn.Dense(256)(x) + x = nn.relu(x) + x = nn.Dense(self.action_dim)(x) + return x + + +# class QNetwork(nn.Module): +# def __init__(self, env): +# super().__init__() +# self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256) +# self.fc2 = nn.Linear(256, 256) +# self.fc3 = nn.Linear(256, 1) + +# def forward(self, x, a): +# x = torch.cat([x, a], 1) +# x = F.relu(self.fc1(x)) +# x = F.relu(self.fc2(x)) +# x = self.fc3(x) +# return x + + +# class Actor(nn.Module): +# def __init__(self, env): +# super().__init__() +# self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256) +# self.fc2 = nn.Linear(256, 256) +# self.fc_mu = nn.Linear(256, np.prod(env.single_action_space.shape)) + +# def forward(self, x): +# x = F.relu(self.fc1(x)) +# x = F.relu(self.fc2(x)) +# return torch.tanh(self.fc_mu(x)) + + +if __name__ == "__main__": + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + # writer = SummaryWriter(f"runs/{run_name}") + # writer.add_text( + # "hyperparameters", + # "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + # ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + jaxRNG = 
jax.random.PRNGKey(0) + + + # device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + + # env setup + # envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) + env = gym.make(args.env_id) + # assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" + + max_action = float(env.action_space.high[0]) + + # actor = Actor(envs).to(device) + + # envs.single_observation_space.dtype = np.float32 + # rb = ReplayBuffer(args.buffer_size, envs.single_observation_space, envs.single_action_space, device="cpu") + start_time = time.time() + + # TRY NOT TO MODIFY: start the game + obs = env.reset() + + + + + actor = Actor(action_dim=np.prod(env.action_space.shape)) + actor_parameters = actor.init(jaxRNG, obs) + actor_sample_fn = jax.jit(actor.apply) + # + # print(output) + # qf1 = QNetwork(envs).to(device) + # qf1_target = QNetwork(envs).to(device) + # target_actor = Actor(envs).to(device) + # target_actor.load_state_dict(actor.state_dict()) + # qf1_target.load_state_dict(qf1.state_dict()) + # q_optimizer = optim.Adam(list(qf1.parameters()), lr=args.learning_rate) + actor_optimizer = optax.adam(learning_rate=args.learning_rate) + actor_optimizer_state = actor_optimizer.init(actor_parameters) + + # raise + for global_step in range(args.total_timesteps): + # ALGO LOGIC: put action logic here + if global_step < args.learning_starts: + actions = env.action_space.sample() + else: + actions = actor_sample_fn(actor_parameters, obs) + # actions = np.array( + # [ + # ( + # actions.tolist()[0] + # + np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape[0]) + # ).clip(envs.single_action_space.low, envs.single_action_space.high) + # ] + # ) + + # TRY NOT TO MODIFY: execute the game and log data. + next_obs, rewards, dones, infos = env.step(actions) + if dones: + next_obs = env.reset() + + # TRY NOT TO MODIFY: record rewards for plotting purposes + # for info in infos: + # if "episode" in info.keys(): + # print(f"global_step={global_step}, episodic_return={info['episode']['r']}") + # writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + # writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) + # break + + # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation` + # real_next_obs = next_obs.copy() + # for idx, d in enumerate(dones): + # if d: + # real_next_obs[idx] = infos[idx]["terminal_observation"] + # rb.add(obs, real_next_obs, actions, rewards, dones, infos) + + # TRY NOT TO MODIFY: CRUCIAL step easy to overlook + obs = next_obs + + if global_step % 10000 == 0: + print("SPS:", int(global_step / (time.time() - start_time))) + # writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + + # ALGO LOGIC: training. 
+ # if global_step > args.learning_starts: + # data = rb.sample(args.batch_size) + # with torch.no_grad(): + # next_state_actions = (target_actor(data.next_observations)).clamp( + # envs.single_action_space.low[0], envs.single_action_space.high[0] + # ) + # qf1_next_target = qf1_target(data.next_observations, next_state_actions) + # next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (qf1_next_target).view(-1) + + # qf1_a_values = qf1(data.observations, data.actions).view(-1) + # qf1_loss = F.mse_loss(qf1_a_values, next_q_value) + + # # optimize the model + # q_optimizer.zero_grad() + # qf1_loss.backward() + # q_optimizer.step() + + # if global_step % args.policy_frequency == 0: + # actor_loss = -qf1(data.observations, actor(data.observations)).mean() + # actor_optimizer.zero_grad() + # actor_loss.backward() + # actor_optimizer.step() + + # # update the target network + # for param, target_param in zip(actor.parameters(), target_actor.parameters()): + # target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) + # for param, target_param in zip(qf1.parameters(), qf1_target.parameters()): + # target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) + + # if global_step % 10000 == 0: + # # writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step) + # # writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step) + # # writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step) + # print("SPS:", int(global_step / (time.time() - start_time))) + # writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + + envs.close() + # writer.close() diff --git a/cleanrl/ddpg_continuous_action_jit.py b/cleanrl/ddpg_continuous_action_jit.py new file mode 100644 index 000000000..60f104255 --- /dev/null +++ b/cleanrl/ddpg_continuous_action_jit.py @@ -0,0 +1,262 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ddpg/#ddpg_continuous_actionpy +# docs and experiment results can be found at +# https://docs.cleanrl.dev/rl-algorithms/ddpg/#ddpg_continuous_actionpy + +import argparse +import os +import random +import time +from distutils.util import strtobool + +import gym +import numpy as np +import pybullet_envs # noqa +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from stable_baselines3.common.buffers import ReplayBuffer +from torch.utils.tensorboard import SummaryWriter + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, `torch.backends.cudnn.deterministic=False`") + parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, cuda will be enabled by default") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="cleanRL", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity 
(team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="weather to capture videos of the agent performances (check out `videos` folder)") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="HopperBulletEnv-v0", + help="the id of the environment") + parser.add_argument("--total-timesteps", type=int, default=1000000, + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=3e-4, + help="the learning rate of the optimizer") + parser.add_argument("--buffer-size", type=int, default=int(1e6), + help="the replay memory buffer size") + parser.add_argument("--gamma", type=float, default=0.99, + help="the discount factor gamma") + parser.add_argument("--tau", type=float, default=0.005, + help="target smoothing coefficient (default: 0.005)") + parser.add_argument("--batch-size", type=int, default=256, + help="the batch size of sample from the reply memory") + parser.add_argument("--exploration-noise", type=float, default=0.1, + help="the scale of exploration noise") + parser.add_argument("--learning-starts", type=int, default=500, + help="timestep to start learning") + parser.add_argument("--policy-frequency", type=int, default=2, + help="the frequency of training policy (delayed)") + parser.add_argument("--noise-clip", type=float, default=0.5, + help="noise clip parameter of the Target Policy Smoothing Regularization") + args = parser.parse_args() + # fmt: on + return args + + +def make_env(env_id, seed, idx, capture_video, run_name): + def thunk(): + env = gym.make(env_id) + env = gym.wrappers.RecordEpisodeStatistics(env) + if capture_video: + if idx == 0: + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + return env + + return thunk + + +# ALGO LOGIC: initialize agent here: +class QNetwork(nn.Module): + def __init__(self, state_dim, action_dim): + super().__init__() + self.fc1 = nn.Linear(state_dim + action_dim, 256) + self.fc2 = nn.Linear(256, 256) + self.fc3 = nn.Linear(256, 1) + + def forward(self, x, a): + x = torch.cat([x, a], 1) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +class Actor(nn.Module): + def __init__(self, state_dim, action_dim): + super().__init__() + self.fc1 = nn.Linear(state_dim, 256) + self.fc2 = nn.Linear(256, 256) + self.fc_mu = nn.Linear(256, action_dim) + + def forward(self, x): + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + return torch.tanh(self.fc_mu(x)) + + +class Agent(nn.Module): + def __init__(self, state_dim, action_dim): + super().__init__() + self.qf1 = QNetwork(state_dim, action_dim) + self.actor = Actor(state_dim, action_dim) + + self.target_qf1 = QNetwork(state_dim, action_dim) + self.target_qf1.load_state_dict(self.qf1.state_dict()) + self.target_actor = Actor(state_dim, action_dim) + self.target_actor.load_state_dict(self.actor.state_dict()) + + @torch.jit.export + def critic_loss(self, next_observations: torch.Tensor, rewards: torch.Tensor, dones: torch.Tensor, observations: torch.Tensor, actions: torch.Tensor, max_action: float, gamma: float): + with torch.no_grad(): + next_state_actions = (self.target_actor(next_observations)).clamp( + -max_action, max_action + ) + qf1_next_target = self.target_qf1(next_observations, next_state_actions) + next_q_value = rewards.flatten() + (1 - dones.flatten()) * gamma * (qf1_next_target).view(-1) + + 
qf1_a_values = self.qf1(observations, actions).view(-1) + qf1_loss = F.mse_loss(qf1_a_values, next_q_value) + # print(f"qf1_a_values.sum():{qf1_a_values.sum()}, next_q_value.sum(): {next_q_value.sum()}") + return qf1_loss + + @torch.jit.export + def actor_loss(self, observations: torch.Tensor): + return -self.qf1(observations, self.actor(observations)).mean() + +if __name__ == "__main__": + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + + # env setup + envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) + assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" + + max_action = float(envs.single_action_space.high[0]) + state_dim = int(np.array(envs.single_observation_space.shape).prod()) + action_dim = int(np.prod(envs.single_action_space.shape)) + agent = torch.jit.script(Agent(state_dim, action_dim).to(device)) + q_optimizer = optim.Adam(agent.qf1.parameters(), lr=args.learning_rate) + actor_optimizer = optim.Adam(agent.actor.parameters(), lr=args.learning_rate) + + envs.single_observation_space.dtype = np.float32 + rb = ReplayBuffer(args.buffer_size, envs.single_observation_space, envs.single_action_space, device=device) + start_time = time.time() + + # TRY NOT TO MODIFY: start the game + obs = envs.reset() + actor_fn = torch.jit.trace(agent.actor, torch.Tensor(obs).to(device)) + data = None + for global_step in range(args.total_timesteps): + # ALGO LOGIC: put action logic here + if global_step < args.learning_starts: + actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) + else: + actions = actor_fn(torch.Tensor(obs).to(device)) + actions = np.array( + [ + ( + actions.tolist()[0] + + np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape[0]) + ).clip(envs.single_action_space.low, envs.single_action_space.high) + ] + ) + + # TRY NOT TO MODIFY: execute the game and log data. + next_obs, rewards, dones, infos = envs.step(actions) + + # TRY NOT TO MODIFY: record rewards for plotting purposes + for info in infos: + if "episode" in info.keys(): + print(f"global_step={global_step}, episodic_return={info['episode']['r']}") + writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) + break + + # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation` + real_next_obs = next_obs.copy() + for idx, d in enumerate(dones): + if d: + real_next_obs[idx] = infos[idx]["terminal_observation"] + rb.add(obs, real_next_obs, actions, rewards, dones, infos) + + # TRY NOT TO MODIFY: CRUCIAL step easy to overlook + obs = next_obs + + # ALGO LOGIC: training. 
+ + if global_step > args.learning_starts: + if data is None: + data = rb.sample(args.batch_size) + print(data) + qf1_loss = agent.critic_loss(data.next_observations, data.rewards, data.dones, data.observations, data.actions, max_action, args.gamma) + + # optimize the model + q_optimizer.zero_grad() + qf1_loss.backward() + q_optimizer.step() + + + if global_step % args.policy_frequency == 0: + actor_loss = agent.actor_loss(data.observations) + actor_optimizer.zero_grad() + actor_loss.backward() + actor_optimizer.step() + + # update the target network + # for param, target_param in zip(agent.actor.parameters(), agent.target_actor.parameters()): + # target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) + # for param, target_param in zip(agent.qf1.parameters(), agent.target_qf1.parameters()): + # target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) + + if global_step % 100 == 0: + # raise + # writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step) + # writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step) + # writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + + envs.close() + writer.close() diff --git a/poetry.lock b/poetry.lock index 6797d20bf..8fea25cfd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -331,6 +331,22 @@ python-versions = ">=3.5.0" [package.extras] unicode_backport = ["unicodedata2"] +[[package]] +name = "chex" +version = "0.1.3" +description = "Chex: Testing made fun, in JAX!" +category = "main" +optional = true +python-versions = ">=3.7" + +[package.dependencies] +absl-py = ">=0.9.0" +dm-tree = ">=0.1.5" +jax = ">=0.1.55" +jaxlib = ">=0.1.37" +numpy = ">=1.18.0" +toolz = ">=0.9.0" + [[package]] name = "click" version = "8.0.4" @@ -556,6 +572,33 @@ mccabe = ">=0.6.0,<0.7.0" pycodestyle = ">=2.7.0,<2.8.0" pyflakes = ">=2.3.0,<2.4.0" +[[package]] +name = "flatbuffers" +version = "2.0" +description = "The FlatBuffers serialization format for Python" +category = "main" +optional = true +python-versions = "*" + +[[package]] +name = "flax" +version = "0.4.2" +description = "Flax: A neural network library for JAX designed for flexibility" +category = "main" +optional = true +python-versions = "*" + +[package.dependencies] +jax = ">=0.3" +matplotlib = "*" +msgpack = "*" +numpy = ">=1.12" +optax = "*" +typing-extensions = ">=4.1.1" + +[package.extras] +testing = ["atari-py (==0.2.5)", "clu", "gym (==0.18.3)", "jaxlib", "jraph", "ml-collections", "opencv-python", "pytest", "pytest-cov", "pytest-xdist (==1.34.0)", "pytype", "sentencepiece", "svn", "tensorflow-text (>=2.4.0)", "tensorflow-datasets", "tensorflow", "torch"] + [[package]] name = "fonttools" version = "4.29.1" @@ -941,6 +984,44 @@ requirements_deprecated_finder = ["pipreqs", "pip-api"] colors = ["colorama (>=0.4.3,<0.5.0)"] plugins = ["setuptools"] +[[package]] +name = "jax" +version = "0.3.12" +description = "Differentiate, compile, and transform Numpy code." 
+category = "main" +optional = true +python-versions = ">=3.7" + +[package.dependencies] +absl-py = "*" +numpy = ">=1.19" +opt_einsum = "*" +scipy = ">=1.2.1" +typing_extensions = "*" + +[package.extras] +ci = ["jaxlib (==0.3.10)"] +cpu = ["jaxlib (==0.3.10)"] +cuda = ["jaxlib (==0.3.10+cuda11.cudnn82)"] +cuda11_cudnn805 = ["jaxlib (==0.3.10+cuda11.cudnn805)"] +cuda11_cudnn82 = ["jaxlib (==0.3.10+cuda11.cudnn82)"] +minimum-jaxlib = ["jaxlib (==0.3.7)"] +tpu = ["jaxlib (==0.3.10)", "libtpu-nightly (==0.1.dev20220504)", "requests"] + +[[package]] +name = "jaxlib" +version = "0.3.10" +description = "XLA library for JAX" +category = "main" +optional = true +python-versions = ">=3.7" + +[package.dependencies] +absl-py = "*" +flatbuffers = ">=1.12,<3.0" +numpy = ">=1.19" +scipy = "*" + [[package]] name = "jedi" version = "0.18.1" @@ -1240,6 +1321,14 @@ category = "main" optional = true python-versions = "*" +[[package]] +name = "msgpack" +version = "1.0.3" +description = "MessagePack (de)serializer." +category = "main" +optional = true +python-versions = "*" + [[package]] name = "mypy-extensions" version = "0.4.3" @@ -1382,6 +1471,37 @@ numpy = [ {version = ">=1.17.3", markers = "python_version >= \"3.8\""}, ] +[[package]] +name = "opt-einsum" +version = "3.3.0" +description = "Optimizing numpys einsum function" +category = "main" +optional = true +python-versions = ">=3.5" + +[package.dependencies] +numpy = ">=1.7" + +[package.extras] +docs = ["sphinx (==1.2.3)", "sphinxcontrib-napoleon", "sphinx-rtd-theme", "numpydoc"] +tests = ["pytest", "pytest-cov", "pytest-pep8"] + +[[package]] +name = "optax" +version = "0.1.2" +description = "A gradient processing and optimisation library in JAX." +category = "main" +optional = true +python-versions = ">=3.7" + +[package.dependencies] +absl-py = ">=0.7.1" +chex = ">=0.0.4" +jax = ">=0.1.55" +jaxlib = ">=0.1.37" +numpy = ">=1.18.0" +typing-extensions = ">=3.10.0" + [[package]] name = "packaging" version = "21.3" @@ -4519,6 +4639,14 @@ category = "main" optional = false python-versions = ">=3.7" +[[package]] +name = "toolz" +version = "0.11.2" +description = "List processing tools and functional utilities" +category = "main" +optional = true +python-versions = ">=3.5" + [[package]] name = "torch" version = "1.10.2" @@ -4754,6 +4882,7 @@ atari = ["ale-py", "AutoROM"] cloud = ["boto3", "awscli"] docs = ["mkdocs-material"] envpool = ["envpool"] +jax = ["jax", "jaxlib", "flax"] mujoco = ["free-mujoco-py"] pettingzoo = ["pettingzoo", "pygame", "pymunk"] plot = ["pandas", "seaborn"] @@ -4765,7 +4894,7 @@ spyder = ["spyder"] [metadata] lock-version = "1.1" python-versions = ">=3.7.1,<3.10" -content-hash = "ef31d9cebf0312d146bf141368d726815c2f603bc02b3aa112398e940c19891a" +content-hash = "9b49962ded2164aad3c2fd9fe3f65ba8546291ca622083f059cdfdfa58348cc9" [metadata.files] absl-py = [ @@ -4962,6 +5091,10 @@ charset-normalizer = [ {file = "charset-normalizer-2.0.12.tar.gz", hash = "sha256:2857e29ff0d34db842cd7ca3230549d1a697f96ee6d3fb071cfa6c7393832597"}, {file = "charset_normalizer-2.0.12-py3-none-any.whl", hash = "sha256:6881edbebdb17b39b4eaaa821b438bf6eddffb4468cf344f09f89def34a8b1df"}, ] +chex = [ + {file = "chex-0.1.3-py3-none-any.whl", hash = "sha256:5ac1dde599259f9dadc819bcd87b60c8bdfc58d732951eb94a2ba21a1aadb69e"}, + {file = "chex-0.1.3.tar.gz", hash = "sha256:2cfa6ccd02addd6b113658d03bd5ce8a7b3bd24fa62e746a246073414ea1e103"}, +] click = [ {file = "click-8.0.4-py3-none-any.whl", hash = "sha256:6a7a62563bbfabfda3a38f3023a1db4a35978c0abd76f6c9605ecd6554d6d9b1"}, 
{file = "click-8.0.4.tar.gz", hash = "sha256:8458d7b1287c5fb128c90e23381cf99dcde74beaf6c7ff6384ce84d6fe090adb"}, @@ -5134,6 +5267,14 @@ flake8 = [ {file = "flake8-3.9.2-py2.py3-none-any.whl", hash = "sha256:bf8fd333346d844f616e8d47905ef3a3384edae6b4e9beb0c5101e25e3110907"}, {file = "flake8-3.9.2.tar.gz", hash = "sha256:07528381786f2a6237b061f6e96610a4167b226cb926e2aa2b6b1d78057c576b"}, ] +flatbuffers = [ + {file = "flatbuffers-2.0-py2.py3-none-any.whl", hash = "sha256:3751954f0604580d3219ae49a85fafec9d85eec599c0b96226e1bc0b48e57474"}, + {file = "flatbuffers-2.0.tar.gz", hash = "sha256:12158ab0272375eab8db2d663ae97370c33f152b27801fa6024e1d6105fd4dd2"}, +] +flax = [ + {file = "flax-0.4.2-py3-none-any.whl", hash = "sha256:37a3293d79cf7b49ca97e2f8fb99b92949013413474e0f059d3b44eec79b5c5a"}, + {file = "flax-0.4.2.tar.gz", hash = "sha256:65d3474570f0e23d9e3cf0650a47edbc937695f276445f877addfaf5ed61913f"}, +] fonttools = [ {file = "fonttools-4.29.1-py3-none-any.whl", hash = "sha256:1933415e0fbdf068815cb1baaa1f159e17830215f7e8624e5731122761627557"}, {file = "fonttools-4.29.1.zip", hash = "sha256:2b18a172120e32128a80efee04cff487d5d140fe7d817deb648b2eee023a40e4"}, @@ -5309,6 +5450,22 @@ isort = [ {file = "isort-5.10.1-py3-none-any.whl", hash = "sha256:6f62d78e2f89b4500b080fe3a81690850cd254227f27f75c3a0c491a1f351ba7"}, {file = "isort-5.10.1.tar.gz", hash = "sha256:e8443a5e7a020e9d7f97f1d7d9cd17c88bcb3bc7e218bf9cf5095fe550be2951"}, ] +jax = [ + {file = "jax-0.3.12.tar.gz", hash = "sha256:47dbf66c11d9737d385275d79c2eb40fc6302ab948e5373b10589e4d00ec93bc"}, +] +jaxlib = [ + {file = "jaxlib-0.3.10-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:7b6219b5cd22580ffcc723c789c248f3f62d4a395289308b99a3a9dbc0687bae"}, + {file = "jaxlib-0.3.10-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:cd780a9509461103df8c71473f4e2333853e0f6884ad5a6362757bd42f8a481b"}, + {file = "jaxlib-0.3.10-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:18e2464f7fc19d4996b095f8c981b36e4b5f06558c63b63a9f792208ceb93fee"}, + {file = "jaxlib-0.3.10-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:f5eebd7be2e679a20af0b47596b46cf9cda696684cb70f233dde1c830b9f4f24"}, + {file = "jaxlib-0.3.10-cp37-none-manylinux2014_x86_64.whl", hash = "sha256:153e99306f134f4efc906a9d223aeaf3777b7c46bbe4ba08243b4f63a7cf592f"}, + {file = "jaxlib-0.3.10-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:3dec003496634685581e7039f0bfa2144efac8aa7626668b46e9d8772e5c19f2"}, + {file = "jaxlib-0.3.10-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:a506cd07565683185e985d2f2d71987a3b82cbe75738d78f3118878b2362e429"}, + {file = "jaxlib-0.3.10-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:18e005b96bf18c46bc43b46ca1681aa5a6037ef1c027359a7336eeca4a237bb4"}, + {file = "jaxlib-0.3.10-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:3875edc0ff7d555e9a60adb88fe7fa724d8a997a4d9610badcaa5ebfd4ebbf00"}, + {file = "jaxlib-0.3.10-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:cd30a9af14fdbf675e6672bb1af77be0ce3040b86fed573c5656b8fb4150d550"}, + {file = "jaxlib-0.3.10-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:e67d2f82f543050b682d85bf33a271a6ad365b7f96d98cc86347a2fd636b55ae"}, +] jedi = [ {file = "jedi-0.18.1-py2.py3-none-any.whl", hash = "sha256:637c9635fcf47945ceb91cd7f320234a7be540ded6f3e99a50cb6febdfd1ba8d"}, {file = "jedi-0.18.1.tar.gz", hash = "sha256:74137626a64a99c8eb6ae5832d99b3bdd7d29a3850fe2aa80a4126b2a7d949ab"}, @@ -5578,6 +5735,42 @@ monotonic = [ {file = "monotonic-1.6-py2.py3-none-any.whl", hash = 
"sha256:68687e19a14f11f26d140dd5c86f3dba4bf5df58003000ed467e0e2a69bca96c"}, {file = "monotonic-1.6.tar.gz", hash = "sha256:3a55207bcfed53ddd5c5bae174524062935efed17792e9de2ad0205ce9ad63f7"}, ] +msgpack = [ + {file = "msgpack-1.0.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:96acc674bb9c9be63fa8b6dabc3248fdc575c4adc005c440ad02f87ca7edd079"}, + {file = "msgpack-1.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2c3ca57c96c8e69c1a0d2926a6acf2d9a522b41dc4253a8945c4c6cd4981a4e3"}, + {file = "msgpack-1.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0a792c091bac433dfe0a70ac17fc2087d4595ab835b47b89defc8bbabcf5c73"}, + {file = "msgpack-1.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c58cdec1cb5fcea8c2f1771d7b5fec79307d056874f746690bd2bdd609ab147"}, + {file = "msgpack-1.0.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f97c0f35b3b096a330bb4a1a9247d0bd7e1f3a2eba7ab69795501504b1c2c39"}, + {file = "msgpack-1.0.3-cp310-cp310-win32.whl", hash = "sha256:36a64a10b16c2ab31dcd5f32d9787ed41fe68ab23dd66957ca2826c7f10d0b85"}, + {file = "msgpack-1.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:c1ba333b4024c17c7591f0f372e2daa3c31db495a9b2af3cf664aef3c14354f7"}, + {file = "msgpack-1.0.3-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:c2140cf7a3ec475ef0938edb6eb363fa704159e0bf71dde15d953bacc1cf9d7d"}, + {file = "msgpack-1.0.3-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f4c22717c74d44bcd7af353024ce71c6b55346dad5e2cc1ddc17ce8c4507c6b"}, + {file = "msgpack-1.0.3-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d733a15ade190540c703de209ffbc42a3367600421b62ac0c09fde594da6ec"}, + {file = "msgpack-1.0.3-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c7e03b06f2982aa98d4ddd082a210c3db200471da523f9ac197f2828e80e7770"}, + {file = "msgpack-1.0.3-cp36-cp36m-win32.whl", hash = "sha256:3d875631ecab42f65f9dce6f55ce6d736696ced240f2634633188de2f5f21af9"}, + {file = "msgpack-1.0.3-cp36-cp36m-win_amd64.whl", hash = "sha256:40fb89b4625d12d6027a19f4df18a4de5c64f6f3314325049f219683e07e678a"}, + {file = "msgpack-1.0.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6eef0cf8db3857b2b556213d97dd82de76e28a6524853a9beb3264983391dc1a"}, + {file = "msgpack-1.0.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d8c332f53ffff01953ad25131272506500b14750c1d0ce8614b17d098252fbc"}, + {file = "msgpack-1.0.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c0903bd93cbd34653dd63bbfcb99d7539c372795201f39d16fdfde4418de43a"}, + {file = "msgpack-1.0.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bf1e6bfed4860d72106f4e0a1ab519546982b45689937b40257cfd820650b920"}, + {file = "msgpack-1.0.3-cp37-cp37m-win32.whl", hash = "sha256:d02cea2252abc3756b2ac31f781f7a98e89ff9759b2e7450a1c7a0d13302ff50"}, + {file = "msgpack-1.0.3-cp37-cp37m-win_amd64.whl", hash = "sha256:2f30dd0dc4dfe6231ad253b6f9f7128ac3202ae49edd3f10d311adc358772dba"}, + {file = "msgpack-1.0.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f201d34dc89342fabb2a10ed7c9a9aaaed9b7af0f16a5923f1ae562b31258dea"}, + {file = "msgpack-1.0.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bb87f23ae7d14b7b3c21009c4b1705ec107cb21ee71975992f6aca571fb4a42a"}, + {file = 
"msgpack-1.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a3a5c4b16e9d0edb823fe54b59b5660cc8d4782d7bf2c214cb4b91a1940a8ef"}, + {file = "msgpack-1.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f74da1e5fcf20ade12c6bf1baa17a2dc3604958922de8dc83cbe3eff22e8b611"}, + {file = "msgpack-1.0.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:73a80bd6eb6bcb338c1ec0da273f87420829c266379c8c82fa14c23fb586cfa1"}, + {file = "msgpack-1.0.3-cp38-cp38-win32.whl", hash = "sha256:9fce00156e79af37bb6db4e7587b30d11e7ac6a02cb5bac387f023808cd7d7f4"}, + {file = "msgpack-1.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:9b6f2d714c506e79cbead331de9aae6837c8dd36190d02da74cb409b36162e8a"}, + {file = "msgpack-1.0.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:89908aea5f46ee1474cc37fbc146677f8529ac99201bc2faf4ef8edc023c2bf3"}, + {file = "msgpack-1.0.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:973ad69fd7e31159eae8f580f3f707b718b61141838321c6fa4d891c4a2cca52"}, + {file = "msgpack-1.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da24375ab4c50e5b7486c115a3198d207954fe10aaa5708f7b65105df09109b2"}, + {file = "msgpack-1.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a598d0685e4ae07a0672b59792d2cc767d09d7a7f39fd9bd37ff84e060b1a996"}, + {file = "msgpack-1.0.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4c309a68cb5d6bbd0c50d5c71a25ae81f268c2dc675c6f4ea8ab2feec2ac4e2"}, + {file = "msgpack-1.0.3-cp39-cp39-win32.whl", hash = "sha256:494471d65b25a8751d19c83f1a482fd411d7ca7a3b9e17d25980a74075ba0e88"}, + {file = "msgpack-1.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:f01b26c2290cbd74316990ba84a14ac3d599af9cebefc543d241a66e785cf17d"}, + {file = "msgpack-1.0.3.tar.gz", hash = "sha256:51fdc7fb93615286428ee7758cecc2f374d5ff363bdd884c7ea622a7a327a81e"}, +] mypy-extensions = [ {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, @@ -5651,6 +5844,14 @@ opencv-python = [ {file = "opencv_python-3.4.17.61-cp36-abi3-win_amd64.whl", hash = "sha256:0f8d5dc4b23d0f15fa8adac251eff6e4c7ae3b903da3f9ffac8b06162386a90d"}, {file = "opencv_python-3.4.17.61-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:3ccb9b600515485b8dfe95364b0947a43aaad4a17013e9168c962719f5787ae9"}, ] +opt-einsum = [ + {file = "opt_einsum-3.3.0-py3-none-any.whl", hash = "sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147"}, + {file = "opt_einsum-3.3.0.tar.gz", hash = "sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549"}, +] +optax = [ + {file = "optax-0.1.2-py3-none-any.whl", hash = "sha256:4e3cb24b70e87acd65700da77c570c468e701d32a2393ae4a5ec35719d90ade6"}, + {file = "optax-0.1.2.tar.gz", hash = "sha256:c2963ffa3b3ac47f72c2866625207c9468558ed18e6e471baac69d4de2ac3f58"}, +] packaging = [ {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"}, {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, @@ -7235,6 +7436,10 @@ tomli = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, {file = 
"tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] +toolz = [ + {file = "toolz-0.11.2-py3-none-any.whl", hash = "sha256:a5700ce83414c64514d82d60bcda8aabfde092d1c1a8663f9200c07fdcc6da8f"}, + {file = "toolz-0.11.2.tar.gz", hash = "sha256:6b312d5e15138552f1bda8a4e66c30e236c831b612b2bf0005f8a1df10a4bc33"}, +] torch = [ {file = "torch-1.10.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:8f3fd2e3ffc3bb867133fdf7fbcc8a0bb2e62a5c0696396f51856f5abf9045a8"}, {file = "torch-1.10.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:258a0729fb77a3457d5822d84b536057cd119b08049a8d3c41dc3dcdeb48d56e"}, diff --git a/pyproject.toml b/pyproject.toml index da51634f3..f20c7ef15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,9 @@ pytest = {version = "^6.2.5", optional = true} free-mujoco-py = {version = "^2.1.6", optional = true} mkdocs-material = {version = "^7.3.4", optional = true} envpool = {version = "^0.4.3", optional = true} +jax = {version = "^0.3.12", optional = true} +jaxlib = {version = "^0.3.10", optional = true} +flax = {version = "^0.4.2", optional = true} [tool.poetry.dev-dependencies] pre-commit = "^2.17.0" @@ -53,3 +56,4 @@ pytest = ["pytest"] mujoco = ["free-mujoco-py"] docs = ["mkdocs-material"] envpool = ["envpool"] +jax = ["jax", "jaxlib", "flax"] From cbc5d88ffd17442b8e74c540ae047b9133dfb3fd Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jun 2022 22:20:45 -0400 Subject: [PATCH 02/27] Quick fix --- cleanrl/ddpg_continuous_action_jax.py | 158 ++++++++++++-------------- cleanrl/ddpg_continuous_action_jit.py | 21 +++- 2 files changed, 90 insertions(+), 89 deletions(-) diff --git a/cleanrl/ddpg_continuous_action_jax.py b/cleanrl/ddpg_continuous_action_jax.py index 04ca399c6..6d4229907 100644 --- a/cleanrl/ddpg_continuous_action_jax.py +++ b/cleanrl/ddpg_continuous_action_jax.py @@ -9,16 +9,13 @@ from distutils.util import strtobool from typing import Sequence +import flax.linen as nn import gym +import jax import numpy as np +import optax import pybullet_envs # noqa import torch -import jax -import jax.numpy as jnp -import flax.linen as nn -import optax -import torch.nn.functional as F -import torch.optim as optim from stable_baselines3.common.buffers import ReplayBuffer from torch.utils.tensorboard import SummaryWriter @@ -74,7 +71,7 @@ def parse_args(): def make_env(env_id, seed, idx, capture_video, run_name): def thunk(): env = gym.make(env_id) - # env = gym.wrappers.RecordEpisodeStatistics(env) + env = gym.wrappers.RecordEpisodeStatistics(env) if capture_video: if idx == 0: env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") @@ -98,9 +95,11 @@ def __call__(self, x): x = nn.relu(x) return x + class Actor(nn.Module): # state_dim = None action_dim: Sequence[int] + @nn.compact def __call__(self, x): x = nn.Dense(256)(x) @@ -154,11 +153,11 @@ def __call__(self, x): monitor_gym=True, save_code=True, ) - # writer = SummaryWriter(f"runs/{run_name}") - # writer.add_text( - # "hyperparameters", - # "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - # ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) # TRY NOT TO MODIFY: seeding random.seed(args.seed) @@ -167,32 +166,27 @@ def __call__(self, x): torch.backends.cudnn.deterministic = args.torch_deterministic jaxRNG = jax.random.PRNGKey(0) - - # device = 
torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") # env setup - # envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) - env = gym.make(args.env_id) - # assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" + envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) + assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" - max_action = float(env.action_space.high[0]) + max_action = float(envs.single_action_space.high[0]) # actor = Actor(envs).to(device) - # envs.single_observation_space.dtype = np.float32 - # rb = ReplayBuffer(args.buffer_size, envs.single_observation_space, envs.single_action_space, device="cpu") + envs.single_observation_space.dtype = np.float32 + rb = ReplayBuffer(args.buffer_size, envs.single_observation_space, envs.single_action_space, device="cpu") start_time = time.time() # TRY NOT TO MODIFY: start the game - obs = env.reset() - - - + obs = envs.reset() - actor = Actor(action_dim=np.prod(env.action_space.shape)) + actor = Actor(action_dim=np.prod(envs.single_observation_space)) actor_parameters = actor.init(jaxRNG, obs) actor_sample_fn = jax.jit(actor.apply) - # + # # print(output) # qf1 = QNetwork(envs).to(device) # qf1_target = QNetwork(envs).to(device) @@ -207,37 +201,35 @@ def __call__(self, x): for global_step in range(args.total_timesteps): # ALGO LOGIC: put action logic here if global_step < args.learning_starts: - actions = env.action_space.sample() + actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) else: actions = actor_sample_fn(actor_parameters, obs) - # actions = np.array( - # [ - # ( - # actions.tolist()[0] - # + np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape[0]) - # ).clip(envs.single_action_space.low, envs.single_action_space.high) - # ] - # ) + actions = np.array( + [ + ( + actions.tolist()[0] + + np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape[0]) + ).clip(envs.single_action_space.low, envs.single_action_space.high) + ] + ) # TRY NOT TO MODIFY: execute the game and log data. 
- next_obs, rewards, dones, infos = env.step(actions) - if dones: - next_obs = env.reset() + next_obs, rewards, dones, infos = envs.step(actions) # TRY NOT TO MODIFY: record rewards for plotting purposes - # for info in infos: - # if "episode" in info.keys(): - # print(f"global_step={global_step}, episodic_return={info['episode']['r']}") - # writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) - # writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) - # break + for info in infos: + if "episode" in info.keys(): + print(f"global_step={global_step}, episodic_return={info['episode']['r']}") + writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) + break # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation` - # real_next_obs = next_obs.copy() - # for idx, d in enumerate(dones): - # if d: - # real_next_obs[idx] = infos[idx]["terminal_observation"] - # rb.add(obs, real_next_obs, actions, rewards, dones, infos) + real_next_obs = next_obs.copy() + for idx, d in enumerate(dones): + if d: + real_next_obs[idx] = infos[idx]["terminal_observation"] + rb.add(obs, real_next_obs, actions, rewards, dones, infos) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs @@ -248,40 +240,40 @@ def __call__(self, x): # ALGO LOGIC: training. # if global_step > args.learning_starts: - # data = rb.sample(args.batch_size) - # with torch.no_grad(): - # next_state_actions = (target_actor(data.next_observations)).clamp( - # envs.single_action_space.low[0], envs.single_action_space.high[0] - # ) - # qf1_next_target = qf1_target(data.next_observations, next_state_actions) - # next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (qf1_next_target).view(-1) - - # qf1_a_values = qf1(data.observations, data.actions).view(-1) - # qf1_loss = F.mse_loss(qf1_a_values, next_q_value) - - # # optimize the model - # q_optimizer.zero_grad() - # qf1_loss.backward() - # q_optimizer.step() - - # if global_step % args.policy_frequency == 0: - # actor_loss = -qf1(data.observations, actor(data.observations)).mean() - # actor_optimizer.zero_grad() - # actor_loss.backward() - # actor_optimizer.step() - - # # update the target network - # for param, target_param in zip(actor.parameters(), target_actor.parameters()): - # target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) - # for param, target_param in zip(qf1.parameters(), qf1_target.parameters()): - # target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) - - # if global_step % 10000 == 0: - # # writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step) - # # writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step) - # # writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step) - # print("SPS:", int(global_step / (time.time() - start_time))) - # writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + # data = rb.sample(args.batch_size) + # with torch.no_grad(): + # next_state_actions = (target_actor(data.next_observations)).clamp( + # envs.single_action_space.low[0], envs.single_action_space.high[0] + # ) + # qf1_next_target = qf1_target(data.next_observations, next_state_actions) + # next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (qf1_next_target).view(-1) + + # qf1_a_values = qf1(data.observations, 
data.actions).view(-1) + # qf1_loss = F.mse_loss(qf1_a_values, next_q_value) + + # # optimize the model + # q_optimizer.zero_grad() + # qf1_loss.backward() + # q_optimizer.step() + + # if global_step % args.policy_frequency == 0: + # actor_loss = -qf1(data.observations, actor(data.observations)).mean() + # actor_optimizer.zero_grad() + # actor_loss.backward() + # actor_optimizer.step() + + # # update the target network + # for param, target_param in zip(actor.parameters(), target_actor.parameters()): + # target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) + # for param, target_param in zip(qf1.parameters(), qf1_target.parameters()): + # target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) + + # if global_step % 10000 == 0: + # # writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step) + # # writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step) + # # writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step) + # print("SPS:", int(global_step / (time.time() - start_time))) + # writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) envs.close() # writer.close() diff --git a/cleanrl/ddpg_continuous_action_jit.py b/cleanrl/ddpg_continuous_action_jit.py index 60f104255..e7cbe2aee 100644 --- a/cleanrl/ddpg_continuous_action_jit.py +++ b/cleanrl/ddpg_continuous_action_jit.py @@ -123,11 +123,18 @@ def __init__(self, state_dim, action_dim): self.target_actor.load_state_dict(self.actor.state_dict()) @torch.jit.export - def critic_loss(self, next_observations: torch.Tensor, rewards: torch.Tensor, dones: torch.Tensor, observations: torch.Tensor, actions: torch.Tensor, max_action: float, gamma: float): + def critic_loss( + self, + next_observations: torch.Tensor, + rewards: torch.Tensor, + dones: torch.Tensor, + observations: torch.Tensor, + actions: torch.Tensor, + max_action: float, + gamma: float, + ): with torch.no_grad(): - next_state_actions = (self.target_actor(next_observations)).clamp( - -max_action, max_action - ) + next_state_actions = (self.target_actor(next_observations)).clamp(-max_action, max_action) qf1_next_target = self.target_qf1(next_observations, next_state_actions) next_q_value = rewards.flatten() + (1 - dones.flatten()) * gamma * (qf1_next_target).view(-1) @@ -140,6 +147,7 @@ def critic_loss(self, next_observations: torch.Tensor, rewards: torch.Tensor, do def actor_loss(self, observations: torch.Tensor): return -self.qf1(observations, self.actor(observations)).mean() + if __name__ == "__main__": args = parse_args() run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" @@ -230,13 +238,14 @@ def actor_loss(self, observations: torch.Tensor): if data is None: data = rb.sample(args.batch_size) print(data) - qf1_loss = agent.critic_loss(data.next_observations, data.rewards, data.dones, data.observations, data.actions, max_action, args.gamma) + qf1_loss = agent.critic_loss( + data.next_observations, data.rewards, data.dones, data.observations, data.actions, max_action, args.gamma + ) # optimize the model q_optimizer.zero_grad() qf1_loss.backward() q_optimizer.step() - if global_step % args.policy_frequency == 0: actor_loss = agent.actor_loss(data.observations) From b4662c224b2840d7b7aff8b51d37aa9b58fc31dc Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 22 Jun 2022 19:14:49 -0400 Subject: [PATCH 03/27] quick fix --- cleanrl/dd.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff 
--git a/cleanrl/dd.py b/cleanrl/dd.py index 67c313cf7..80a326f80 100644 --- a/cleanrl/dd.py +++ b/cleanrl/dd.py @@ -1,23 +1,24 @@ from typing import Sequence -import numpy as np +import flax.linen as nn import jax import jax.numpy as jnp -import flax.linen as nn + class MLP(nn.Module): - features: Sequence[int] + features: Sequence[int] + + @nn.compact + def __call__(self, x): + for feat in self.features[:-1]: + x = nn.relu(nn.Dense(feat)(x)) + x = nn.Dense(self.features[-1])(x) + return x - @nn.compact - def __call__(self, x): - for feat in self.features[:-1]: - x = nn.relu(nn.Dense(feat)(x)) - x = nn.Dense(self.features[-1])(x) - return x model = MLP([12, 8, 4]) batch = jnp.ones((32, 10)) variables = model.init(jax.random.PRNGKey(0), batch) # print(variables) # for _ in range(4000): -# output = +# output = From 754a0b12b55cfb2ce48c6ded0cb7e73ce3140724 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 24 Jun 2022 18:31:05 -0400 Subject: [PATCH 04/27] Commit changes - successful prototype --- cleanrl/ddpg_continuous_action_jax.py | 245 ++++++++++++++++---------- 1 file changed, 156 insertions(+), 89 deletions(-) diff --git a/cleanrl/ddpg_continuous_action_jax.py b/cleanrl/ddpg_continuous_action_jax.py index 6d4229907..75ad676d1 100644 --- a/cleanrl/ddpg_continuous_action_jax.py +++ b/cleanrl/ddpg_continuous_action_jax.py @@ -1,20 +1,18 @@ -# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ddpg/#ddpg_continuous_actionpy -# docs and experiment results can be found at -# https://docs.cleanrl.dev/rl-algorithms/ddpg/#ddpg_continuous_actionpy - import argparse +import functools import os import random import time +import copy from distutils.util import strtobool from typing import Sequence import flax.linen as nn import gym import jax +import jax.numpy as jnp import numpy as np import optax -import pybullet_envs # noqa import torch from stable_baselines3.common.buffers import ReplayBuffer from torch.utils.tensorboard import SummaryWriter @@ -85,19 +83,20 @@ def thunk(): # ALGO LOGIC: initialize agent here: class QNetwork(nn.Module): + obs_dim: Sequence[int] + action_dim: Sequence[int] @nn.compact - def __call__(self, x): - x = nn.Dense(256)(x) - x = nn.relu(x) + def __call__(self, x: jnp.ndarray, a: jnp.ndarray): + x = jnp.concatenate([x, a], -1) x = nn.Dense(256)(x) x = nn.relu(x) x = nn.Dense(256)(x) x = nn.relu(x) + x = nn.Dense(1)(x) return x class Actor(nn.Module): - # state_dim = None action_dim: Sequence[int] @nn.compact @@ -107,36 +106,90 @@ def __call__(self, x): x = nn.Dense(256)(x) x = nn.relu(x) x = nn.Dense(self.action_dim)(x) + x = nn.tanh(x) return x -# class QNetwork(nn.Module): -# def __init__(self, env): -# super().__init__() -# self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256) -# self.fc2 = nn.Linear(256, 256) -# self.fc3 = nn.Linear(256, 1) - -# def forward(self, x, a): -# x = torch.cat([x, a], 1) -# x = F.relu(self.fc1(x)) -# x = F.relu(self.fc2(x)) -# x = self.fc3(x) -# return x - - -# class Actor(nn.Module): -# def __init__(self, env): -# super().__init__() -# self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256) -# self.fc2 = nn.Linear(256, 256) -# self.fc_mu = nn.Linear(256, np.prod(env.single_action_space.shape)) - -# def forward(self, x): -# x = F.relu(self.fc1(x)) -# x = F.relu(self.fc2(x)) -# return torch.tanh(self.fc_mu(x)) - +@functools.partial(jax.jit, static_argnames=('actor', 'qf1', 'qf1', 'qf1_optimizer', 'actor_optimizer')) +def 
forward( + actor, + actor_parameters, + actor_target_parameters, + qf1, + qf1_parameters, + qf1_target_parameters, + observations, + actions, + next_observations, + rewards, + dones, + gamma, + tau, + qf1_optimizer, + qf1_optimizer_state, + actor_optimizer, + actor_optimizer_state, +): + next_state_actions = (actor.apply(actor_target_parameters, next_observations)).clip(-1, 1) + qf1_next_target = qf1.apply(qf1_target_parameters, next_observations, next_state_actions).reshape(-1) + next_q_value = (rewards + (1 - dones) * gamma * (qf1_next_target)).reshape(-1) + + # def mse_loss(qf1_parameters, observations, actions, next_q_value): + # return ((qf1.apply(qf1_parameters, observations, actions) - next_q_value) ** 2).mean() + + @jax.jit + def mse_loss(qf1_parameters, observations, actions, next_q_value): + # Define the squared loss for a single pair (x,y) + def squared_error(x, a, y): + pred = qf1.apply(qf1_parameters, x, a) + return jnp.inner(y-pred, y-pred) / 2.0 + # Vectorize the previous to compute the average of the loss on all samples. + return jnp.mean(jax.vmap(squared_error)(observations,actions, next_q_value), axis=0) + + qf1_loss_value, grads = jax.value_and_grad(mse_loss)(qf1_parameters, observations, actions, next_q_value) + updates, qf1_optimizer_state = qf1_optimizer.update(grads, qf1_optimizer_state) + qf1_parameters = optax.apply_updates(qf1_parameters, updates) + + return qf1_loss_value, 0, qf1_parameters, qf1_target_parameters, qf1_optimizer_state, actor_parameters, actor_target_parameters, actor_optimizer_state + + +@functools.partial(jax.jit, static_argnames=('actor', 'qf1', 'qf1', 'qf1_optimizer', 'actor_optimizer')) +def forward2( + actor, + actor_parameters, + actor_target_parameters, + qf1, + qf1_parameters, + qf1_target_parameters, + observations, + actions, + next_observations, + rewards, + dones, + gamma, + tau, + qf1_optimizer, + qf1_optimizer_state, + actor_optimizer, + actor_optimizer_state, +): + def actor_loss(actor_parameters, qf1_parameters, observations): + return -qf1.apply(qf1_parameters, observations, actor.apply(actor_parameters, observations)).mean() + + actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_parameters, qf1_parameters, observations) + updates, actor_optimizer_state = actor_optimizer.update(grads, actor_optimizer_state) + actor_parameters = optax.apply_updates(actor_parameters, updates) + + actor_target_parameters = update_target(actor_parameters, actor_target_parameters, tau) + qf1_target_parameters = update_target(qf1_parameters, qf1_target_parameters, tau) + + return 0, actor_loss_value, qf1_parameters, qf1_target_parameters, qf1_optimizer_state, actor_parameters, actor_target_parameters, actor_optimizer_state + + +def update_target(src, dst, tau): + return jax.tree_map( + lambda p, tp: p * tau + tp * (1 - tau), src, dst + ) if __name__ == "__main__": args = parse_args() @@ -183,19 +236,20 @@ def __call__(self, x): # TRY NOT TO MODIFY: start the game obs = envs.reset() - actor = Actor(action_dim=np.prod(envs.single_observation_space)) + actor = Actor(action_dim=np.prod(envs.single_action_space.shape)) actor_parameters = actor.init(jaxRNG, obs) - actor_sample_fn = jax.jit(actor.apply) - # - # print(output) - # qf1 = QNetwork(envs).to(device) - # qf1_target = QNetwork(envs).to(device) - # target_actor = Actor(envs).to(device) - # target_actor.load_state_dict(actor.state_dict()) - # qf1_target.load_state_dict(qf1.state_dict()) - # q_optimizer = optim.Adam(list(qf1.parameters()), lr=args.learning_rate) + actor_target_parameters = 
actor.init(jaxRNG, obs) + actor.apply = jax.jit(actor.apply) + qf1 = QNetwork(obs_dim=np.prod(envs.single_observation_space.shape), action_dim=np.prod(envs.single_action_space.shape)) + qf1_parameters = qf1.init(jaxRNG, obs, envs.action_space.sample()) + qf1_target_parameters = qf1.init(jaxRNG, obs, envs.action_space.sample()) + qf1.apply = jax.jit(qf1.apply) + actor_target_parameters = update_target(actor_parameters, actor_target_parameters, 1.0) + qf1_target_parameters = update_target(qf1_parameters, qf1_target_parameters, 1.0) actor_optimizer = optax.adam(learning_rate=args.learning_rate) actor_optimizer_state = actor_optimizer.init(actor_parameters) + qf1_optimizer = optax.adam(learning_rate=args.learning_rate) + qf1_optimizer_state = qf1_optimizer.init(qf1_parameters) # raise for global_step in range(args.total_timesteps): @@ -203,11 +257,11 @@ def __call__(self, x): if global_step < args.learning_starts: actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) else: - actions = actor_sample_fn(actor_parameters, obs) + actions = actor.apply(actor_parameters, obs) actions = np.array( [ ( - actions.tolist()[0] + np.array(actions)[0] + np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape[0]) ).clip(envs.single_action_space.low, envs.single_action_space.high) ] @@ -234,46 +288,59 @@ def __call__(self, x): # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs - if global_step % 10000 == 0: - print("SPS:", int(global_step / (time.time() - start_time))) - # writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) - - # ALGO LOGIC: training. - # if global_step > args.learning_starts: - # data = rb.sample(args.batch_size) - # with torch.no_grad(): - # next_state_actions = (target_actor(data.next_observations)).clamp( - # envs.single_action_space.low[0], envs.single_action_space.high[0] - # ) - # qf1_next_target = qf1_target(data.next_observations, next_state_actions) - # next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (qf1_next_target).view(-1) - - # qf1_a_values = qf1(data.observations, data.actions).view(-1) - # qf1_loss = F.mse_loss(qf1_a_values, next_q_value) - - # # optimize the model - # q_optimizer.zero_grad() - # qf1_loss.backward() - # q_optimizer.step() - - # if global_step % args.policy_frequency == 0: - # actor_loss = -qf1(data.observations, actor(data.observations)).mean() - # actor_optimizer.zero_grad() - # actor_loss.backward() - # actor_optimizer.step() - - # # update the target network - # for param, target_param in zip(actor.parameters(), target_actor.parameters()): - # target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) - # for param, target_param in zip(qf1.parameters(), qf1_target.parameters()): - # target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) - - # if global_step % 10000 == 0: - # # writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step) - # # writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step) - # # writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step) - # print("SPS:", int(global_step / (time.time() - start_time))) - # writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + # # ALGO LOGIC: training. 
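For reference (not part of the patch): the `update_target` helper above is a plain polyak average over two parameter pytrees, and calling it with tau=1.0 performs the initial hard copy of the online parameters into the targets. A minimal standalone sketch of the same rule on toy pytrees, cross-checked against optax's `incremental_update`, which computes step_size * new + (1 - step_size) * old:

import jax
import jax.numpy as jnp
import optax

# Toy pytrees standing in for qf1_parameters / qf1_target_parameters.
online = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros(2)}}
target = {"dense": {"kernel": jnp.zeros((2, 2)), "bias": jnp.zeros(2)}}

tau = 0.005
# Hand-rolled polyak average, the same rule as update_target in the patch.
soft = jax.tree_util.tree_map(lambda p, tp: p * tau + tp * (1 - tau), online, target)
# optax's built-in equivalent of the same soft update.
soft_optax = optax.incremental_update(online, target, step_size=tau)
assert all(
    jax.tree_util.tree_leaves(
        jax.tree_util.tree_map(lambda a, b: bool(jnp.allclose(a, b)), soft, soft_optax)
    )
)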
+ if global_step > args.learning_starts: + data = rb.sample(args.batch_size) + qf1_loss_value, _, qf1_parameters, qf1_target_parameters, qf1_optimizer_state, actor_parameters, actor_target_parameters, actor_optimizer_state = forward( + actor, + actor_parameters, + actor_target_parameters, + qf1, + qf1_parameters, + qf1_target_parameters, + data.observations.numpy(), + data.actions.numpy(), + data.next_observations.numpy(), + data.rewards.flatten().numpy(), + data.dones.flatten().numpy(), + args.gamma, + args.tau, + qf1_optimizer, + qf1_optimizer_state, + actor_optimizer, + actor_optimizer_state, + ) + + if global_step % args.policy_frequency == 0: + _, actor_loss_value, qf1_parameters, qf1_target_parameters, qf1_optimizer_state, actor_parameters, actor_target_parameters, actor_optimizer_state = forward2( + actor, + actor_parameters, + actor_target_parameters, + qf1, + qf1_parameters, + qf1_target_parameters, + data.observations.numpy(), + data.actions.numpy(), + data.next_observations.numpy(), + data.rewards.flatten().numpy(), + data.dones.flatten().numpy(), + args.gamma, + args.tau, + qf1_optimizer, + qf1_optimizer_state, + actor_optimizer, + actor_optimizer_state, + ) + + + # print(actor_parameters["params"]["Dense_0"]["kernel"].sum()) + if global_step % 100 == 0: + # print(qf1_target_parameters["params"]["Dense_0"]["kernel"].sum()) + writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step) + writer.add_scalar("losses/actor_loss", actor_loss_value.item(), global_step) + # writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) envs.close() - # writer.close() + writer.close() From 223a8ffa6a180e2aa9a7527db5cc9f05e2e2580d Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 24 Jun 2022 21:37:26 -0400 Subject: [PATCH 05/27] Remove scripts --- cleanrl/dd.py | 24 --- cleanrl/ddpg_continuous_action_jit.py | 271 -------------------------- 2 files changed, 295 deletions(-) delete mode 100644 cleanrl/dd.py delete mode 100644 cleanrl/ddpg_continuous_action_jit.py diff --git a/cleanrl/dd.py b/cleanrl/dd.py deleted file mode 100644 index 80a326f80..000000000 --- a/cleanrl/dd.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Sequence - -import flax.linen as nn -import jax -import jax.numpy as jnp - - -class MLP(nn.Module): - features: Sequence[int] - - @nn.compact - def __call__(self, x): - for feat in self.features[:-1]: - x = nn.relu(nn.Dense(feat)(x)) - x = nn.Dense(self.features[-1])(x) - return x - - -model = MLP([12, 8, 4]) -batch = jnp.ones((32, 10)) -variables = model.init(jax.random.PRNGKey(0), batch) -# print(variables) -# for _ in range(4000): -# output = diff --git a/cleanrl/ddpg_continuous_action_jit.py b/cleanrl/ddpg_continuous_action_jit.py deleted file mode 100644 index e7cbe2aee..000000000 --- a/cleanrl/ddpg_continuous_action_jit.py +++ /dev/null @@ -1,271 +0,0 @@ -# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ddpg/#ddpg_continuous_actionpy -# docs and experiment results can be found at -# https://docs.cleanrl.dev/rl-algorithms/ddpg/#ddpg_continuous_actionpy - -import argparse -import os -import random -import time -from distutils.util import strtobool - -import gym -import numpy as np -import pybullet_envs # noqa -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from 
stable_baselines3.common.buffers import ReplayBuffer -from torch.utils.tensorboard import SummaryWriter - - -def parse_args(): - # fmt: off - parser = argparse.ArgumentParser() - parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), - help="the name of this experiment") - parser.add_argument("--seed", type=int, default=1, - help="seed of the experiment") - parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="if toggled, `torch.backends.cudnn.deterministic=False`") - parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="if toggled, cuda will be enabled by default") - parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, - help="if toggled, this experiment will be tracked with Weights and Biases") - parser.add_argument("--wandb-project-name", type=str, default="cleanRL", - help="the wandb's project name") - parser.add_argument("--wandb-entity", type=str, default=None, - help="the entity (team) of wandb's project") - parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, - help="weather to capture videos of the agent performances (check out `videos` folder)") - - # Algorithm specific arguments - parser.add_argument("--env-id", type=str, default="HopperBulletEnv-v0", - help="the id of the environment") - parser.add_argument("--total-timesteps", type=int, default=1000000, - help="total timesteps of the experiments") - parser.add_argument("--learning-rate", type=float, default=3e-4, - help="the learning rate of the optimizer") - parser.add_argument("--buffer-size", type=int, default=int(1e6), - help="the replay memory buffer size") - parser.add_argument("--gamma", type=float, default=0.99, - help="the discount factor gamma") - parser.add_argument("--tau", type=float, default=0.005, - help="target smoothing coefficient (default: 0.005)") - parser.add_argument("--batch-size", type=int, default=256, - help="the batch size of sample from the reply memory") - parser.add_argument("--exploration-noise", type=float, default=0.1, - help="the scale of exploration noise") - parser.add_argument("--learning-starts", type=int, default=500, - help="timestep to start learning") - parser.add_argument("--policy-frequency", type=int, default=2, - help="the frequency of training policy (delayed)") - parser.add_argument("--noise-clip", type=float, default=0.5, - help="noise clip parameter of the Target Policy Smoothing Regularization") - args = parser.parse_args() - # fmt: on - return args - - -def make_env(env_id, seed, idx, capture_video, run_name): - def thunk(): - env = gym.make(env_id) - env = gym.wrappers.RecordEpisodeStatistics(env) - if capture_video: - if idx == 0: - env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") - env.seed(seed) - env.action_space.seed(seed) - env.observation_space.seed(seed) - return env - - return thunk - - -# ALGO LOGIC: initialize agent here: -class QNetwork(nn.Module): - def __init__(self, state_dim, action_dim): - super().__init__() - self.fc1 = nn.Linear(state_dim + action_dim, 256) - self.fc2 = nn.Linear(256, 256) - self.fc3 = nn.Linear(256, 1) - - def forward(self, x, a): - x = torch.cat([x, a], 1) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x - - -class Actor(nn.Module): - def __init__(self, state_dim, action_dim): - super().__init__() - self.fc1 = 
nn.Linear(state_dim, 256) - self.fc2 = nn.Linear(256, 256) - self.fc_mu = nn.Linear(256, action_dim) - - def forward(self, x): - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - return torch.tanh(self.fc_mu(x)) - - -class Agent(nn.Module): - def __init__(self, state_dim, action_dim): - super().__init__() - self.qf1 = QNetwork(state_dim, action_dim) - self.actor = Actor(state_dim, action_dim) - - self.target_qf1 = QNetwork(state_dim, action_dim) - self.target_qf1.load_state_dict(self.qf1.state_dict()) - self.target_actor = Actor(state_dim, action_dim) - self.target_actor.load_state_dict(self.actor.state_dict()) - - @torch.jit.export - def critic_loss( - self, - next_observations: torch.Tensor, - rewards: torch.Tensor, - dones: torch.Tensor, - observations: torch.Tensor, - actions: torch.Tensor, - max_action: float, - gamma: float, - ): - with torch.no_grad(): - next_state_actions = (self.target_actor(next_observations)).clamp(-max_action, max_action) - qf1_next_target = self.target_qf1(next_observations, next_state_actions) - next_q_value = rewards.flatten() + (1 - dones.flatten()) * gamma * (qf1_next_target).view(-1) - - qf1_a_values = self.qf1(observations, actions).view(-1) - qf1_loss = F.mse_loss(qf1_a_values, next_q_value) - # print(f"qf1_a_values.sum():{qf1_a_values.sum()}, next_q_value.sum(): {next_q_value.sum()}") - return qf1_loss - - @torch.jit.export - def actor_loss(self, observations: torch.Tensor): - return -self.qf1(observations, self.actor(observations)).mean() - - -if __name__ == "__main__": - args = parse_args() - run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=vars(args), - name=run_name, - monitor_gym=True, - save_code=True, - ) - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - - # TRY NOT TO MODIFY: seeding - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.backends.cudnn.deterministic = args.torch_deterministic - - device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") - - # env setup - envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) - assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" - - max_action = float(envs.single_action_space.high[0]) - state_dim = int(np.array(envs.single_observation_space.shape).prod()) - action_dim = int(np.prod(envs.single_action_space.shape)) - agent = torch.jit.script(Agent(state_dim, action_dim).to(device)) - q_optimizer = optim.Adam(agent.qf1.parameters(), lr=args.learning_rate) - actor_optimizer = optim.Adam(agent.actor.parameters(), lr=args.learning_rate) - - envs.single_observation_space.dtype = np.float32 - rb = ReplayBuffer(args.buffer_size, envs.single_observation_space, envs.single_action_space, device=device) - start_time = time.time() - - # TRY NOT TO MODIFY: start the game - obs = envs.reset() - actor_fn = torch.jit.trace(agent.actor, torch.Tensor(obs).to(device)) - data = None - for global_step in range(args.total_timesteps): - # ALGO LOGIC: put action logic here - if global_step < args.learning_starts: - actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) - else: - actions = 
actor_fn(torch.Tensor(obs).to(device)) - actions = np.array( - [ - ( - actions.tolist()[0] - + np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape[0]) - ).clip(envs.single_action_space.low, envs.single_action_space.high) - ] - ) - - # TRY NOT TO MODIFY: execute the game and log data. - next_obs, rewards, dones, infos = envs.step(actions) - - # TRY NOT TO MODIFY: record rewards for plotting purposes - for info in infos: - if "episode" in info.keys(): - print(f"global_step={global_step}, episodic_return={info['episode']['r']}") - writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) - writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) - break - - # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation` - real_next_obs = next_obs.copy() - for idx, d in enumerate(dones): - if d: - real_next_obs[idx] = infos[idx]["terminal_observation"] - rb.add(obs, real_next_obs, actions, rewards, dones, infos) - - # TRY NOT TO MODIFY: CRUCIAL step easy to overlook - obs = next_obs - - # ALGO LOGIC: training. - - if global_step > args.learning_starts: - if data is None: - data = rb.sample(args.batch_size) - print(data) - qf1_loss = agent.critic_loss( - data.next_observations, data.rewards, data.dones, data.observations, data.actions, max_action, args.gamma - ) - - # optimize the model - q_optimizer.zero_grad() - qf1_loss.backward() - q_optimizer.step() - - if global_step % args.policy_frequency == 0: - actor_loss = agent.actor_loss(data.observations) - actor_optimizer.zero_grad() - actor_loss.backward() - actor_optimizer.step() - - # update the target network - # for param, target_param in zip(agent.actor.parameters(), agent.target_actor.parameters()): - # target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) - # for param, target_param in zip(agent.qf1.parameters(), agent.target_qf1.parameters()): - # target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) - - if global_step % 100 == 0: - # raise - # writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step) - # writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step) - # writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step) - print("SPS:", int(global_step / (time.time() - start_time))) - writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) - - envs.close() - writer.close() From 85fbfe2af0c3c044f4cb1199f2fb6f9ea262131f Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 24 Jun 2022 22:05:49 -0400 Subject: [PATCH 06/27] Simplify the implementation: careful with shape --- cleanrl/ddpg_continuous_action_jax.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/cleanrl/ddpg_continuous_action_jax.py b/cleanrl/ddpg_continuous_action_jax.py index 75ad676d1..b0dc89783 100644 --- a/cleanrl/ddpg_continuous_action_jax.py +++ b/cleanrl/ddpg_continuous_action_jax.py @@ -134,17 +134,8 @@ def forward( qf1_next_target = qf1.apply(qf1_target_parameters, next_observations, next_state_actions).reshape(-1) next_q_value = (rewards + (1 - dones) * gamma * (qf1_next_target)).reshape(-1) - # def mse_loss(qf1_parameters, observations, actions, next_q_value): - # return ((qf1.apply(qf1_parameters, observations, actions) - next_q_value) ** 2).mean() - - @jax.jit def mse_loss(qf1_parameters, observations, actions, next_q_value): - # Define the squared loss for a single pair (x,y) - def squared_error(x, 
a, y): - pred = qf1.apply(qf1_parameters, x, a) - return jnp.inner(y-pred, y-pred) / 2.0 - # Vectorize the previous to compute the average of the loss on all samples. - return jnp.mean(jax.vmap(squared_error)(observations,actions, next_q_value), axis=0) + return ((qf1.apply(qf1_parameters, observations, actions).squeeze() - next_q_value) ** 2).mean() qf1_loss_value, grads = jax.value_and_grad(mse_loss)(qf1_parameters, observations, actions, next_q_value) updates, qf1_optimizer_state = qf1_optimizer.update(grads, qf1_optimizer_state) From 8ffbd26cf1065ba200341787d975022d441bee93 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 24 Jun 2022 22:50:01 -0400 Subject: [PATCH 07/27] Format --- cleanrl/ddpg_continuous_action_jax.py | 161 ++++++++++---------------- 1 file changed, 59 insertions(+), 102 deletions(-) diff --git a/cleanrl/ddpg_continuous_action_jax.py b/cleanrl/ddpg_continuous_action_jax.py index b0dc89783..48649ad11 100644 --- a/cleanrl/ddpg_continuous_action_jax.py +++ b/cleanrl/ddpg_continuous_action_jax.py @@ -1,9 +1,7 @@ import argparse -import functools import os import random import time -import copy from distutils.util import strtobool from typing import Sequence @@ -85,6 +83,7 @@ def thunk(): class QNetwork(nn.Module): obs_dim: Sequence[int] action_dim: Sequence[int] + @nn.compact def __call__(self, x: jnp.ndarray, a: jnp.ndarray): x = jnp.concatenate([x, a], -1) @@ -110,77 +109,9 @@ def __call__(self, x): return x -@functools.partial(jax.jit, static_argnames=('actor', 'qf1', 'qf1', 'qf1_optimizer', 'actor_optimizer')) -def forward( - actor, - actor_parameters, - actor_target_parameters, - qf1, - qf1_parameters, - qf1_target_parameters, - observations, - actions, - next_observations, - rewards, - dones, - gamma, - tau, - qf1_optimizer, - qf1_optimizer_state, - actor_optimizer, - actor_optimizer_state, -): - next_state_actions = (actor.apply(actor_target_parameters, next_observations)).clip(-1, 1) - qf1_next_target = qf1.apply(qf1_target_parameters, next_observations, next_state_actions).reshape(-1) - next_q_value = (rewards + (1 - dones) * gamma * (qf1_next_target)).reshape(-1) - - def mse_loss(qf1_parameters, observations, actions, next_q_value): - return ((qf1.apply(qf1_parameters, observations, actions).squeeze() - next_q_value) ** 2).mean() - - qf1_loss_value, grads = jax.value_and_grad(mse_loss)(qf1_parameters, observations, actions, next_q_value) - updates, qf1_optimizer_state = qf1_optimizer.update(grads, qf1_optimizer_state) - qf1_parameters = optax.apply_updates(qf1_parameters, updates) - - return qf1_loss_value, 0, qf1_parameters, qf1_target_parameters, qf1_optimizer_state, actor_parameters, actor_target_parameters, actor_optimizer_state - - -@functools.partial(jax.jit, static_argnames=('actor', 'qf1', 'qf1', 'qf1_optimizer', 'actor_optimizer')) -def forward2( - actor, - actor_parameters, - actor_target_parameters, - qf1, - qf1_parameters, - qf1_target_parameters, - observations, - actions, - next_observations, - rewards, - dones, - gamma, - tau, - qf1_optimizer, - qf1_optimizer_state, - actor_optimizer, - actor_optimizer_state, -): - def actor_loss(actor_parameters, qf1_parameters, observations): - return -qf1.apply(qf1_parameters, observations, actor.apply(actor_parameters, observations)).mean() - - actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_parameters, qf1_parameters, observations) - updates, actor_optimizer_state = actor_optimizer.update(grads, actor_optimizer_state) - actor_parameters = optax.apply_updates(actor_parameters, updates) - - 
actor_target_parameters = update_target(actor_parameters, actor_target_parameters, tau) - qf1_target_parameters = update_target(qf1_parameters, qf1_target_parameters, tau) - - return 0, actor_loss_value, qf1_parameters, qf1_target_parameters, qf1_optimizer_state, actor_parameters, actor_target_parameters, actor_optimizer_state - - def update_target(src, dst, tau): - return jax.tree_map( - lambda p, tp: p * tau + tp * (1 - tau), src, dst - ) + return jax.tree_map(lambda p, tp: p * tau + tp * (1 - tau), src, dst) + if __name__ == "__main__": args = parse_args() @@ -242,7 +173,50 @@ def update_target(src, dst, tau): qf1_optimizer = optax.adam(learning_rate=args.learning_rate) qf1_optimizer_state = qf1_optimizer.init(qf1_parameters) - # raise + @jax.jit + def update_critic( + observations, + actions, + next_observations, + rewards, + dones, + actor_target_parameters, + qf1_parameters, + qf1_target_parameters, + qf1_optimizer_state, + ): + next_state_actions = (actor.apply(actor_target_parameters, next_observations)).clip(-1, 1) + qf1_next_target = qf1.apply(qf1_target_parameters, next_observations, next_state_actions).reshape(-1) + next_q_value = (rewards + (1 - dones) * args.gamma * (qf1_next_target)).reshape(-1) + + def mse_loss(qf1_parameters, observations, actions, next_q_value): + return ((qf1.apply(qf1_parameters, observations, actions).squeeze() - next_q_value) ** 2).mean() + + qf1_loss_value, grads = jax.value_and_grad(mse_loss)(qf1_parameters, observations, actions, next_q_value) + updates, qf1_optimizer_state = qf1_optimizer.update(grads, qf1_optimizer_state) + qf1_parameters = optax.apply_updates(qf1_parameters, updates) + return qf1_loss_value, qf1_parameters, qf1_optimizer_state + + @jax.jit + def update_actor( + observations, + actor_parameters, + actor_target_parameters, + qf1_parameters, + qf1_target_parameters, + actor_optimizer_state, + ): + def actor_loss(actor_parameters, qf1_parameters, observations): + return -qf1.apply(qf1_parameters, observations, actor.apply(actor_parameters, observations)).mean() + + actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_parameters, qf1_parameters, observations) + updates, actor_optimizer_state = actor_optimizer.update(grads, actor_optimizer_state) + actor_parameters = optax.apply_updates(actor_parameters, updates) + + actor_target_parameters = update_target(actor_parameters, actor_target_parameters, args.tau) + qf1_target_parameters = update_target(qf1_parameters, qf1_target_parameters, args.tau) + return actor_loss_value, actor_parameters, actor_optimizer_state, actor_target_parameters, qf1_target_parameters + for global_step in range(args.total_timesteps): # ALGO LOGIC: put action logic here if global_step < args.learning_starts: @@ -282,51 +256,34 @@ def update_target(src, dst, tau): # # ALGO LOGIC: training. 
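The refactor above defines `update_critic` and `update_actor` inside `__main__` under a bare `@jax.jit`, so `actor`, `qf1`, the optimizers and `args` are captured by the closures as compile-time constants and only arrays cross the jit boundary; this is what lets the earlier `functools.partial(jax.jit, static_argnames=...)` plumbing on `forward`/`forward2` go away. A sketch of the same pattern on a toy loss (hypothetical names, not from the patch):

import jax
import jax.numpy as jnp
import optax

optimizer = optax.adam(1e-3)
params = {"w": jnp.zeros((3,))}
opt_state = optimizer.init(params)

@jax.jit  # `optimizer` is captured from the enclosing scope, so it never has to be a (static) argument
def train_step(params, opt_state, x, y):
    def loss_fn(p):
        return jnp.mean((x @ p["w"] - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return loss, params, opt_state

loss, params, opt_state = train_step(params, opt_state, jnp.ones((8, 3)), jnp.zeros(8))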
if global_step > args.learning_starts: data = rb.sample(args.batch_size) - qf1_loss_value, _, qf1_parameters, qf1_target_parameters, qf1_optimizer_state, actor_parameters, actor_target_parameters, actor_optimizer_state = forward( - actor, - actor_parameters, - actor_target_parameters, - qf1, - qf1_parameters, - qf1_target_parameters, + qf1_loss_value, qf1_parameters, qf1_optimizer_state = update_critic( data.observations.numpy(), data.actions.numpy(), data.next_observations.numpy(), data.rewards.flatten().numpy(), data.dones.flatten().numpy(), - args.gamma, - args.tau, - qf1_optimizer, + actor_target_parameters, + qf1_parameters, + qf1_target_parameters, qf1_optimizer_state, - actor_optimizer, - actor_optimizer_state, ) - if global_step % args.policy_frequency == 0: - _, actor_loss_value, qf1_parameters, qf1_target_parameters, qf1_optimizer_state, actor_parameters, actor_target_parameters, actor_optimizer_state = forward2( - actor, + ( + actor_loss_value, actor_parameters, + actor_optimizer_state, actor_target_parameters, - qf1, - qf1_parameters, qf1_target_parameters, + ) = update_actor( data.observations.numpy(), - data.actions.numpy(), - data.next_observations.numpy(), - data.rewards.flatten().numpy(), - data.dones.flatten().numpy(), - args.gamma, - args.tau, - qf1_optimizer, - qf1_optimizer_state, - actor_optimizer, + actor_parameters, + actor_target_parameters, + qf1_parameters, + qf1_target_parameters, actor_optimizer_state, ) - - # print(actor_parameters["params"]["Dense_0"]["kernel"].sum()) if global_step % 100 == 0: - # print(qf1_target_parameters["params"]["Dense_0"]["kernel"].sum()) writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step) writer.add_scalar("losses/actor_loss", actor_loss_value.item(), global_step) # writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step) From c72cfb7c3ee9624393c48ab16df25cf6644eb81d Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 24 Jun 2022 22:52:45 -0400 Subject: [PATCH 08/27] Remove code --- cleanrl/ddpg_continuous_action_jax.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cleanrl/ddpg_continuous_action_jax.py b/cleanrl/ddpg_continuous_action_jax.py index 48649ad11..ca5bd05bc 100644 --- a/cleanrl/ddpg_continuous_action_jax.py +++ b/cleanrl/ddpg_continuous_action_jax.py @@ -97,7 +97,6 @@ def __call__(self, x: jnp.ndarray, a: jnp.ndarray): class Actor(nn.Module): action_dim: Sequence[int] - @nn.compact def __call__(self, x): x = nn.Dense(256)(x) From bfece786fec4e78afbf856bf39f9d07c31ad5486 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 24 Jun 2022 22:55:21 -0400 Subject: [PATCH 09/27] formatting changes --- cleanrl/ddpg_continuous_action_jax.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/cleanrl/ddpg_continuous_action_jax.py b/cleanrl/ddpg_continuous_action_jax.py index ca5bd05bc..c3d6f7cb3 100644 --- a/cleanrl/ddpg_continuous_action_jax.py +++ b/cleanrl/ddpg_continuous_action_jax.py @@ -11,7 +11,6 @@ import jax.numpy as jnp import numpy as np import optax -import torch from stable_baselines3.common.buffers import ReplayBuffer from torch.utils.tensorboard import SummaryWriter @@ -136,20 +135,14 @@ def update_target(src, dst, tau): # TRY NOT TO MODIFY: seeding random.seed(args.seed) np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.backends.cudnn.deterministic = args.torch_deterministic jaxRNG = jax.random.PRNGKey(0) - device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") - # env setup envs = 
gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" max_action = float(envs.single_action_space.high[0]) - # actor = Actor(envs).to(device) - envs.single_observation_space.dtype = np.float32 rb = ReplayBuffer(args.buffer_size, envs.single_observation_space, envs.single_action_space, device="cpu") start_time = time.time() @@ -252,7 +245,7 @@ def actor_loss(actor_parameters, qf1_parameters, observations): # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs - # # ALGO LOGIC: training. + # ALGO LOGIC: training. if global_step > args.learning_starts: data = rb.sample(args.batch_size) qf1_loss_value, qf1_parameters, qf1_optimizer_state = update_critic( From 0710728a18560c25e0120a8044d6624ffa13dd81 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 24 Jun 2022 23:01:47 -0400 Subject: [PATCH 10/27] formatting change --- cleanrl/ddpg_continuous_action_jax.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cleanrl/ddpg_continuous_action_jax.py b/cleanrl/ddpg_continuous_action_jax.py index c3d6f7cb3..e4f51ef4e 100644 --- a/cleanrl/ddpg_continuous_action_jax.py +++ b/cleanrl/ddpg_continuous_action_jax.py @@ -82,7 +82,6 @@ def thunk(): class QNetwork(nn.Module): obs_dim: Sequence[int] action_dim: Sequence[int] - @nn.compact def __call__(self, x: jnp.ndarray, a: jnp.ndarray): x = jnp.concatenate([x, a], -1) From 92d9d1301955aead425694b70fd814f70147998d Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sat, 25 Jun 2022 10:21:12 -0400 Subject: [PATCH 11/27] bug fix --- cleanrl/ddpg_continuous_action_jax.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cleanrl/ddpg_continuous_action_jax.py b/cleanrl/ddpg_continuous_action_jax.py index e4f51ef4e..67bad7709 100644 --- a/cleanrl/ddpg_continuous_action_jax.py +++ b/cleanrl/ddpg_continuous_action_jax.py @@ -80,8 +80,6 @@ def thunk(): # ALGO LOGIC: initialize agent here: class QNetwork(nn.Module): - obs_dim: Sequence[int] - action_dim: Sequence[int] @nn.compact def __call__(self, x: jnp.ndarray, a: jnp.ndarray): x = jnp.concatenate([x, a], -1) @@ -134,7 +132,7 @@ def update_target(src, dst, tau): # TRY NOT TO MODIFY: seeding random.seed(args.seed) np.random.seed(args.seed) - jaxRNG = jax.random.PRNGKey(0) + jaxRNG = jax.random.PRNGKey(args.seed) # env setup envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) @@ -153,7 +151,7 @@ def update_target(src, dst, tau): actor_parameters = actor.init(jaxRNG, obs) actor_target_parameters = actor.init(jaxRNG, obs) actor.apply = jax.jit(actor.apply) - qf1 = QNetwork(obs_dim=np.prod(envs.single_observation_space.shape), action_dim=np.prod(envs.single_action_space.shape)) + qf1 = QNetwork() qf1_parameters = qf1.init(jaxRNG, obs, envs.action_space.sample()) qf1_target_parameters = qf1.init(jaxRNG, obs, envs.action_space.sample()) qf1.apply = jax.jit(qf1.apply) From cc6e2fa47b5a9f397db42daaa9e3a909ef355b2a Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sat, 25 Jun 2022 17:29:52 -0400 Subject: [PATCH 12/27] Prototype JAX + PPO + envpool's MuJoCo --- cleanrl/ppo_continuous_action_envpool_jax.py | 362 +++++++++++++++++++ 1 file changed, 362 insertions(+) create mode 100644 cleanrl/ppo_continuous_action_envpool_jax.py diff --git a/cleanrl/ppo_continuous_action_envpool_jax.py b/cleanrl/ppo_continuous_action_envpool_jax.py new file mode 100644 index 000000000..c537c5c5a --- /dev/null +++ 
b/cleanrl/ppo_continuous_action_envpool_jax.py @@ -0,0 +1,362 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy +import argparse +import os +import random +import time +from collections import deque +from distutils.util import strtobool +from typing import Sequence + +import envpool +import flax.linen as nn +import gym +import jax +import jax.numpy as jnp +import numpy as np +import optax +from torch.utils.tensorboard import SummaryWriter + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, `torch.backends.cudnn.deterministic=False`") + parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, cuda will be enabled by default") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="cleanRL", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="weather to capture videos of the agent performances (check out `videos` folder)") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="HalfCheetah-v4", + help="the id of the environment") + parser.add_argument("--total-timesteps", type=int, default=10000000, + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=3e-4, + help="the learning rate of the optimizer") + parser.add_argument("--num-envs", type=int, default=64, + help="the number of parallel game environments") + parser.add_argument("--num-steps", type=int, default=128, + help="the number of steps to run in each environment per policy rollout") + parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggle learning rate annealing for policy and value networks") + parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Use GAE for advantage computation") + parser.add_argument("--gamma", type=float, default=0.99, + help="the discount factor gamma") + parser.add_argument("--gae-lambda", type=float, default=0.95, + help="the lambda for the general advantage estimation") + parser.add_argument("--num-minibatches", type=int, default=2, + help="the number of mini-batches") + parser.add_argument("--update-epochs", type=int, default=4, + help="the K epochs to update the policy") + parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles advantages normalization") + parser.add_argument("--clip-coef", type=float, default=0.1, + help="the surrogate clipping coefficient") + parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles whether or not to use a clipped loss for the value 
function, as per the paper.") + parser.add_argument("--ent-coef", type=float, default=0.01, + help="coefficient of the entropy") + parser.add_argument("--vf-coef", type=float, default=0.5, + help="coefficient of the value function") + parser.add_argument("--max-grad-norm", type=float, default=0.5, + help="the maximum norm for the gradient clipping") + parser.add_argument("--target-kl", type=float, default=None, + help="the target KL divergence threshold") + args = parser.parse_args() + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + # fmt: on + return args + + +class RecordEpisodeStatistics(gym.Wrapper): + def __init__(self, env, deque_size=100): + super().__init__(env) + self.num_envs = getattr(env, "num_envs", 1) + self.episode_returns = None + self.episode_lengths = None + + def reset(self, **kwargs): + observations = super().reset(**kwargs) + self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) + self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) + self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32) + self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32) + return observations + + def step(self, action): + observations, rewards, dones, infos = super().step(action) + self.episode_returns += rewards + self.episode_lengths += 1 + self.returned_episode_returns[:] = self.episode_returns + self.returned_episode_lengths[:] = self.episode_lengths + self.episode_returns *= 1 - dones + self.episode_lengths *= 1 - dones + infos["r"] = self.returned_episode_returns + infos["l"] = self.returned_episode_lengths + return ( + observations, + rewards, + dones, + infos, + ) + + +# def layer_init(layer, std=np.sqrt(2), bias_const=0.0): +# torch.nn.init.orthogonal_(layer.weight, std) +# torch.nn.init.constant_(layer.bias, bias_const) +# return layer + + +class Critic(nn.Module): + @nn.compact + def __call__(self, x): + critic = nn.Dense(64)(x) + critic = nn.tanh(critic) + critic = nn.Dense(64)(critic) + critic = nn.tanh(critic) + critic = nn.Dense(1)(critic) + return critic + + +class Actor(nn.Module): + action_dim: Sequence[int] + + @nn.compact + def __call__(self, x): + actor_mean = nn.Dense(64)(x) + actor_mean = nn.tanh(actor_mean) + actor_mean = nn.Dense(64)(actor_mean) + actor_mean = nn.tanh(actor_mean) + actor_mean = nn.Dense(self.action_dim)(actor_mean) + actor_logstd = jnp.zeros((1, self.action_dim)) + return actor_mean, actor_logstd + + +if __name__ == "__main__": + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + jaxRNG = jax.random.PRNGKey(args.seed) + + # env setup + envs = envpool.make( + args.env_id, + env_type="gym", + num_envs=args.num_envs, + ) + envs.num_envs = args.num_envs + envs.single_action_space = envs.action_space + envs.single_observation_space = envs.observation_space + envs.is_vector_env = True + envs = RecordEpisodeStatistics(envs) + envs = gym.wrappers.ClipAction(envs) + envs = 
gym.wrappers.NormalizeObservation(envs) + envs = gym.wrappers.TransformObservation(envs, lambda obs: np.clip(obs, -10, 10)) + envs = gym.wrappers.NormalizeReward(envs) + envs = gym.wrappers.TransformReward(envs, lambda reward: np.clip(reward, -10, 10)) + assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" + + actor = Actor(action_dim=np.prod(envs.single_action_space.shape)) + actor_parameters = actor.init(jaxRNG, envs.single_observation_space.sample()) + actor.apply = jax.jit(actor.apply) + critic = Critic() + critic_parameters = critic.init(jaxRNG, envs.single_observation_space.sample()) + critic.apply = jax.jit(critic.apply) + actor_optimizer = optax.adam(learning_rate=args.learning_rate, eps=1e-5) + actor_optimizer_state = actor_optimizer.init((actor_parameters, critic_parameters)) + + # ALGO Logic: Storage setup + obs = jnp.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape) + actions = jnp.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape) + logprobs = jnp.zeros((args.num_steps, args.num_envs)) + rewards = jnp.zeros((args.num_steps, args.num_envs)) + dones = jnp.zeros((args.num_steps, args.num_envs)) + values = jnp.zeros((args.num_steps, args.num_envs)) + avg_returns = deque(maxlen=20) + + @jax.jit + def get_action_and_value(x, obs, actions, logprobs, values, step, jaxRNG, actor_parameters, critic_parameters): + obs = obs.at[step].set(x) # inside jit() `x = x.at[idx].set(y)` is in-place. + action_mean, action_logstd = actor.apply(actor_parameters, x) + action_std = jnp.exp(action_logstd) + action = action_mean + action_std * jax.random.normal(jaxRNG, shape=action_mean.shape) + logprob = -0.5 * ((action - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd + entropy = action_logstd + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e) + value = critic.apply(critic_parameters, x) + actions = actions.at[step].set(action) + logprobs = logprobs.at[step].set(logprob.sum(1)) + values = values.at[step].set(value.squeeze()) + return x, obs, actions, logprobs, values, action, logprob, entropy, value + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs = envs.reset() + next_done = np.zeros(args.num_envs) + num_updates = args.total_timesteps // args.batch_size + + for update in range(1, num_updates + 1): + # Annealing the rate if instructed to do so. + # if args.anneal_lr: + # frac = 1.0 - (update - 1.0) / num_updates + # lrnow = frac * args.learning_rate + # optimizer.param_groups[0]["lr"] = lrnow + + for step in range(0, args.num_steps): + global_step += 1 * args.num_envs + next_obs, obs, actions, logprobs, values, action, logprob, entropy, value = get_action_and_value( + next_obs, obs, actions, logprobs, values, step, jaxRNG, actor_parameters, critic_parameters + ) + + # TRY NOT TO MODIFY: execute the game and log data. 
+ next_obs, reward, next_done, info = envs.step(np.array(action)) + for idx, d in enumerate(next_done): + if d: + print(f"global_step={global_step}, episodic_return={info['r'][idx]}") + avg_returns.append(info["r"][idx]) + writer.add_scalar("charts/avg_episodic_return", np.average(avg_returns), global_step) + writer.add_scalar("charts/episodic_return", info["r"][idx], global_step) + writer.add_scalar("charts/episodic_length", info["l"][idx], global_step) + + # # bootstrap value if not done + # with torch.no_grad(): + # next_value = agent.get_value(next_obs).reshape(1, -1) + # if args.gae: + # advantages = torch.zeros_like(rewards).to(device) + # lastgaelam = 0 + # for t in reversed(range(args.num_steps)): + # if t == args.num_steps - 1: + # nextnonterminal = 1.0 - next_done + # nextvalues = next_value + # else: + # nextnonterminal = 1.0 - dones[t + 1] + # nextvalues = values[t + 1] + # delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + # advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + # returns = advantages + values + # else: + # returns = torch.zeros_like(rewards).to(device) + # for t in reversed(range(args.num_steps)): + # if t == args.num_steps - 1: + # nextnonterminal = 1.0 - next_done + # next_return = next_value + # else: + # nextnonterminal = 1.0 - dones[t + 1] + # next_return = returns[t + 1] + # returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return + # advantages = returns - values + + # # flatten the batch + # b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) + # b_logprobs = logprobs.reshape(-1) + # b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + # b_advantages = advantages.reshape(-1) + # b_returns = returns.reshape(-1) + # b_values = values.reshape(-1) + + # # Optimizing the policy and value network + # b_inds = np.arange(args.batch_size) + # clipfracs = [] + # for epoch in range(args.update_epochs): + # np.random.shuffle(b_inds) + # for start in range(0, args.batch_size, args.minibatch_size): + # end = start + args.minibatch_size + # mb_inds = b_inds[start:end] + + # _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds]) + # logratio = newlogprob - b_logprobs[mb_inds] + # ratio = logratio.exp() + + # with torch.no_grad(): + # # calculate approx_kl http://joschu.net/blog/kl-approx.html + # old_approx_kl = (-logratio).mean() + # approx_kl = ((ratio - 1) - logratio).mean() + # clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] + + # mb_advantages = b_advantages[mb_inds] + # if args.norm_adv: + # mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # # Policy loss + # pg_loss1 = -mb_advantages * ratio + # pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + # pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # # Value loss + # newvalue = newvalue.view(-1) + # if args.clip_vloss: + # v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 + # v_clipped = b_values[mb_inds] + torch.clamp( + # newvalue - b_values[mb_inds], + # -args.clip_coef, + # args.clip_coef, + # ) + # v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 + # v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + # v_loss = 0.5 * v_loss_max.mean() + # else: + # v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() + + # entropy_loss = entropy.mean() + # loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + + # 
optimizer.zero_grad() + # loss.backward() + # nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + # optimizer.step() + + # if args.target_kl is not None: + # if approx_kl > args.target_kl: + # break + + # y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + # var_y = np.var(y_true) + # explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + # # TRY NOT TO MODIFY: record rewards for plotting purposes + # writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) + # writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + # writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + # writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + # writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + # writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + # writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + # writer.add_scalar("losses/explained_variance", explained_var, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + + envs.close() + writer.close() From c769efcec42b624534cf2508a64872b1ce62af18 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sun, 26 Jun 2022 12:17:28 -0400 Subject: [PATCH 13/27] next step --- cleanrl/ppo_continuous_action_envpool_jax.py | 254 ++++++++++--------- 1 file changed, 137 insertions(+), 117 deletions(-) diff --git a/cleanrl/ppo_continuous_action_envpool_jax.py b/cleanrl/ppo_continuous_action_envpool_jax.py index c537c5c5a..311c90144 100644 --- a/cleanrl/ppo_continuous_action_envpool_jax.py +++ b/cleanrl/ppo_continuous_action_envpool_jax.py @@ -7,7 +7,10 @@ from distutils.util import strtobool from typing import Sequence +from cv2 import log + import envpool +import flax import flax.linen as nn import gym import jax @@ -42,11 +45,11 @@ def parse_args(): help="the id of the environment") parser.add_argument("--total-timesteps", type=int, default=10000000, help="total timesteps of the experiments") - parser.add_argument("--learning-rate", type=float, default=3e-4, + parser.add_argument("--learning-rate", type=float, default=0.00295, help="the learning rate of the optimizer") parser.add_argument("--num-envs", type=int, default=64, help="the number of parallel game environments") - parser.add_argument("--num-steps", type=int, default=128, + parser.add_argument("--num-steps", type=int, default=64, help="the number of steps to run in each environment per policy rollout") parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, help="Toggle learning rate annealing for policy and value networks") @@ -56,21 +59,21 @@ def parse_args(): help="the discount factor gamma") parser.add_argument("--gae-lambda", type=float, default=0.95, help="the lambda for the general advantage estimation") - parser.add_argument("--num-minibatches", type=int, default=2, + parser.add_argument("--num-minibatches", type=int, default=4, help="the number of mini-batches") - parser.add_argument("--update-epochs", type=int, default=4, + parser.add_argument("--update-epochs", type=int, default=2, help="the K epochs to update the policy") parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, help="Toggles advantages normalization") - parser.add_argument("--clip-coef", type=float, default=0.1, + 
parser.add_argument("--clip-coef", type=float, default=0.2, help="the surrogate clipping coefficient") - parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, help="Toggles whether or not to use a clipped loss for the value function, as per the paper.") - parser.add_argument("--ent-coef", type=float, default=0.01, + parser.add_argument("--ent-coef", type=float, default=0.0, help="coefficient of the entropy") - parser.add_argument("--vf-coef", type=float, default=0.5, + parser.add_argument("--vf-coef", type=float, default=1.3, help="coefficient of the value function") - parser.add_argument("--max-grad-norm", type=float, default=0.5, + parser.add_argument("--max-grad-norm", type=float, default=3.5, help="the maximum norm for the gradient clipping") parser.add_argument("--target-kl", type=float, default=None, help="the target KL divergence threshold") @@ -133,7 +136,6 @@ def __call__(self, x): class Actor(nn.Module): action_dim: Sequence[int] - @nn.compact def __call__(self, x): actor_mean = nn.Dense(64)(x) @@ -144,6 +146,11 @@ def __call__(self, x): actor_logstd = jnp.zeros((1, self.action_dim)) return actor_mean, actor_logstd +@flax.struct.dataclass +class AgentParams: + actor_params: flax.core.FrozenDict + critic_params: flax.core.FrozenDict + if __name__ == "__main__": args = parse_args() @@ -169,7 +176,8 @@ def __call__(self, x): # TRY NOT TO MODIFY: seeding random.seed(args.seed) np.random.seed(args.seed) - jaxRNG = jax.random.PRNGKey(args.seed) + key = jax.random.PRNGKey(args.seed) + key, actor_key, critic_key = jax.random.split(key, 3) # env setup envs = envpool.make( @@ -190,13 +198,17 @@ def __call__(self, x): assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" actor = Actor(action_dim=np.prod(envs.single_action_space.shape)) - actor_parameters = actor.init(jaxRNG, envs.single_observation_space.sample()) + actor_params = actor.init(actor_key, envs.single_observation_space.sample()) actor.apply = jax.jit(actor.apply) critic = Critic() - critic_parameters = critic.init(jaxRNG, envs.single_observation_space.sample()) + critic_params = critic.init(critic_key, envs.single_observation_space.sample()) critic.apply = jax.jit(critic.apply) - actor_optimizer = optax.adam(learning_rate=args.learning_rate, eps=1e-5) - actor_optimizer_state = actor_optimizer.init((actor_parameters, critic_parameters)) + agent_optimizer = optax.adam(learning_rate=args.learning_rate, eps=1e-5) + agent_params = AgentParams( + actor_params, + critic_params, + ) + agent_optimizer_state = agent_optimizer.init(agent_params) # ALGO Logic: Storage setup obs = jnp.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape) @@ -205,21 +217,106 @@ def __call__(self, x): rewards = jnp.zeros((args.num_steps, args.num_envs)) dones = jnp.zeros((args.num_steps, args.num_envs)) values = jnp.zeros((args.num_steps, args.num_envs)) + advantages = jnp.zeros((args.num_steps, args.num_envs)) avg_returns = deque(maxlen=20) @jax.jit - def get_action_and_value(x, obs, actions, logprobs, values, step, jaxRNG, actor_parameters, critic_parameters): + def get_action_and_value(x, obs, actions, logprobs, values, step, agent_params, key): obs = obs.at[step].set(x) # inside jit() `x = x.at[idx].set(y)` is in-place. 
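The hand-written `logprob` and `entropy` expressions in `get_action_and_value` are the per-dimension log-density and differential entropy of an independent Normal(mean, exp(logstd)); summing over the action axis gives the joint log-probability of the sampled action. A quick numerical cross-check against `jax.scipy.stats.norm` (illustrative values only, not part of the patch):

import jax.numpy as jnp
from jax.scipy.stats import norm

mean = jnp.array([0.3, -0.1])
logstd = jnp.array([-0.5, 0.2])
std = jnp.exp(logstd)
action = jnp.array([0.25, 0.4])

manual = -0.5 * ((action - mean) / std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - logstd
assert jnp.allclose(manual, norm.logpdf(action, loc=mean, scale=std))

# Per-dimension differential entropy of a Gaussian: log(std) + 0.5 * log(2 * pi * e).
entropy = logstd + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e)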
- action_mean, action_logstd = actor.apply(actor_parameters, x) + action_mean, action_logstd = actor.apply(agent_params.actor_params, x) action_std = jnp.exp(action_logstd) - action = action_mean + action_std * jax.random.normal(jaxRNG, shape=action_mean.shape) + key, subkey = jax.random.split(key) + action = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape) logprob = -0.5 * ((action - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd entropy = action_logstd + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e) - value = critic.apply(critic_parameters, x) + value = critic.apply(agent_params.critic_params, x) actions = actions.at[step].set(action) logprobs = logprobs.at[step].set(logprob.sum(1)) values = values.at[step].set(value.squeeze()) - return x, obs, actions, logprobs, values, action, logprob, entropy, value + return x, obs, actions, logprobs, values, action, logprob, entropy, value, key + + @jax.jit + def get_action_and_value2(x, action, agent_params): + action_mean, action_logstd = actor.apply(agent_params.actor_params, x) + action_std = jnp.exp(action_logstd) + logprob = -0.5 * ((action - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd + entropy = action_logstd + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e) + value = critic.apply(agent_params.critic_params, x) + return logprob.sum(1), entropy, value + + @jax.jit + def compute_gae(next_obs, next_done, rewards, dones, values, advantages, agent_params): + advantages = advantages.at[:].set(0.0) # reset advantages + next_value = critic.apply(agent_params.critic_params, next_obs).squeeze() + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + advantages = advantages.at[t].set(lastgaelam) + returns = advantages + values + return jax.lax.stop_gradient(advantages), jax.lax.stop_gradient(returns) + + @jax.jit + def update_ppo( + obs, + logprobs, + actions, + advantages, + returns, + values, + agent_params, + agent_optimizer_state, + key + ): + b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) + b_logprobs = logprobs.reshape(-1) + b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + b_advantages = advantages.reshape(-1) + b_returns = returns.reshape(-1) + b_values = values.reshape(-1) + + def ppo_loss(agent_params, x, a, logp, adv, ret): + newlogprob, _, newvalue = get_action_and_value2(x, a, agent_params) + logratio = newlogprob - logp + ratio = jnp.exp(logratio) + + mb_advantages = adv + # if args.norm_adv: + # mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean() + + # Value loss + v_loss = 0.5 * ((newvalue - ret) ** 2).mean() + + # entropy_loss = entropy.mean() + # loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + loss = pg_loss + v_loss * args.vf_coef + return loss + ppo_loss_grad_fn = jax.value_and_grad(ppo_loss) + + b_inds = jnp.arange(args.batch_size) + # clipfracs = [] + for _ in range(args.update_epochs): + key, subkey = jax.random.split(key) + b_inds = 
jax.random.shuffle(subkey, b_inds) + for start in range(0, args.batch_size, args.minibatch_size): + end = start + args.minibatch_size + mb_inds = b_inds[start:end] + loss, grads = ppo_loss_grad_fn(agent_params, b_obs[mb_inds], b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds], b_returns[mb_inds]) + updates, agent_optimizer_state = agent_optimizer.update(grads, agent_optimizer_state) + agent_params = optax.apply_updates(agent_params, updates) + + return loss, key, agent_params, agent_optimizer_state # TRY NOT TO MODIFY: start the game global_step = 0 @@ -237,8 +334,8 @@ def get_action_and_value(x, obs, actions, logprobs, values, step, jaxRNG, actor_ for step in range(0, args.num_steps): global_step += 1 * args.num_envs - next_obs, obs, actions, logprobs, values, action, logprob, entropy, value = get_action_and_value( - next_obs, obs, actions, logprobs, values, step, jaxRNG, actor_parameters, critic_parameters + next_obs, obs, actions, logprobs, values, action, logprob, entropy, value, key = get_action_and_value( + next_obs, obs, actions, logprobs, values, step, agent_params, key ) # TRY NOT TO MODIFY: execute the game and log data. @@ -251,100 +348,23 @@ def get_action_and_value(x, obs, actions, logprobs, values, step, jaxRNG, actor_ writer.add_scalar("charts/episodic_return", info["r"][idx], global_step) writer.add_scalar("charts/episodic_length", info["l"][idx], global_step) - # # bootstrap value if not done - # with torch.no_grad(): - # next_value = agent.get_value(next_obs).reshape(1, -1) - # if args.gae: - # advantages = torch.zeros_like(rewards).to(device) - # lastgaelam = 0 - # for t in reversed(range(args.num_steps)): - # if t == args.num_steps - 1: - # nextnonterminal = 1.0 - next_done - # nextvalues = next_value - # else: - # nextnonterminal = 1.0 - dones[t + 1] - # nextvalues = values[t + 1] - # delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] - # advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam - # returns = advantages + values - # else: - # returns = torch.zeros_like(rewards).to(device) - # for t in reversed(range(args.num_steps)): - # if t == args.num_steps - 1: - # nextnonterminal = 1.0 - next_done - # next_return = next_value - # else: - # nextnonterminal = 1.0 - dones[t + 1] - # next_return = returns[t + 1] - # returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return - # advantages = returns - values - - # # flatten the batch - # b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) - # b_logprobs = logprobs.reshape(-1) - # b_actions = actions.reshape((-1,) + envs.single_action_space.shape) - # b_advantages = advantages.reshape(-1) - # b_returns = returns.reshape(-1) - # b_values = values.reshape(-1) - - # # Optimizing the policy and value network - # b_inds = np.arange(args.batch_size) - # clipfracs = [] - # for epoch in range(args.update_epochs): - # np.random.shuffle(b_inds) - # for start in range(0, args.batch_size, args.minibatch_size): - # end = start + args.minibatch_size - # mb_inds = b_inds[start:end] - - # _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds]) - # logratio = newlogprob - b_logprobs[mb_inds] - # ratio = logratio.exp() - - # with torch.no_grad(): - # # calculate approx_kl http://joschu.net/blog/kl-approx.html - # old_approx_kl = (-logratio).mean() - # approx_kl = ((ratio - 1) - logratio).mean() - # clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] - - # mb_advantages = 
b_advantages[mb_inds] - # if args.norm_adv: - # mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) - - # # Policy loss - # pg_loss1 = -mb_advantages * ratio - # pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) - # pg_loss = torch.max(pg_loss1, pg_loss2).mean() - - # # Value loss - # newvalue = newvalue.view(-1) - # if args.clip_vloss: - # v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 - # v_clipped = b_values[mb_inds] + torch.clamp( - # newvalue - b_values[mb_inds], - # -args.clip_coef, - # args.clip_coef, - # ) - # v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 - # v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) - # v_loss = 0.5 * v_loss_max.mean() - # else: - # v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() - - # entropy_loss = entropy.mean() - # loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef - - # optimizer.zero_grad() - # loss.backward() - # nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) - # optimizer.step() - - # if args.target_kl is not None: - # if approx_kl > args.target_kl: - # break - - # y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() - # var_y = np.var(y_true) - # explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + advantages, returns = compute_gae(next_obs, next_done, rewards, dones, values, advantages, agent_params) + # print(advantages.sum(), returns.sum()) + # raise + loss, key, agent_params, agent_optimizer_state = update_ppo( + obs, + logprobs, + actions, + advantages, + returns, + values, + agent_params, + agent_optimizer_state, + key, + ) + + print(agent_params.actor_params["params"]["Dense_0"]["kernel"].sum(), agent_params.critic_params["params"]["Dense_0"]["kernel"].sum()) + # # TRY NOT TO MODIFY: record rewards for plotting purposes # writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) @@ -354,7 +374,7 @@ def get_action_and_value(x, obs, actions, logprobs, values, step, jaxRNG, actor_ # writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) # writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) # writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) - # writer.add_scalar("losses/explained_variance", explained_var, global_step) + writer.add_scalar("losses/loss", loss.item(), global_step) print("SPS:", int(global_step / (time.time() - start_time))) writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) From 30c4dde3c4c0ccd2f7bbe84942e6f927c89fdd48 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sun, 26 Jun 2022 20:19:08 -0400 Subject: [PATCH 14/27] successful prototype --- cleanrl/ppo_continuous_action_envpool_jax.py | 110 +++++++++---------- 1 file changed, 54 insertions(+), 56 deletions(-) diff --git a/cleanrl/ppo_continuous_action_envpool_jax.py b/cleanrl/ppo_continuous_action_envpool_jax.py index 311c90144..9cced2e23 100644 --- a/cleanrl/ppo_continuous_action_envpool_jax.py +++ b/cleanrl/ppo_continuous_action_envpool_jax.py @@ -7,11 +7,10 @@ from distutils.util import strtobool from typing import Sequence -from cv2 import log - import envpool import flax import flax.linen as nn +from flax.linen.initializers import orthogonal, constant import gym import jax import jax.numpy as jnp @@ -41,7 +40,7 @@ def parse_args(): help="weather to capture videos of the agent performances (check out `videos` folder)") # Algorithm specific arguments - 
parser.add_argument("--env-id", type=str, default="HalfCheetah-v4", + parser.add_argument("--env-id", type=str, default="Ant-v4", help="the id of the environment") parser.add_argument("--total-timesteps", type=int, default=10000000, help="total timesteps of the experiments") @@ -116,21 +115,14 @@ def step(self, action): infos, ) - -# def layer_init(layer, std=np.sqrt(2), bias_const=0.0): -# torch.nn.init.orthogonal_(layer.weight, std) -# torch.nn.init.constant_(layer.bias, bias_const) -# return layer - - class Critic(nn.Module): @nn.compact def __call__(self, x): - critic = nn.Dense(64)(x) + critic = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x) critic = nn.tanh(critic) - critic = nn.Dense(64)(critic) + critic = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(critic) critic = nn.tanh(critic) - critic = nn.Dense(1)(critic) + critic = nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))(critic) return critic @@ -138,18 +130,19 @@ class Actor(nn.Module): action_dim: Sequence[int] @nn.compact def __call__(self, x): - actor_mean = nn.Dense(64)(x) + actor_mean = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x) actor_mean = nn.tanh(actor_mean) - actor_mean = nn.Dense(64)(actor_mean) + actor_mean = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(actor_mean) actor_mean = nn.tanh(actor_mean) - actor_mean = nn.Dense(self.action_dim)(actor_mean) - actor_logstd = jnp.zeros((1, self.action_dim)) + actor_mean = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(actor_mean) + actor_logstd = self.param('actor_logstd', constant(0.0), (1, self.action_dim)) return actor_mean, actor_logstd + @flax.struct.dataclass class AgentParams: - actor_params: flax.core.FrozenDict - critic_params: flax.core.FrozenDict + actor_params: flax.core.FrozenDict + critic_params: flax.core.FrozenDict if __name__ == "__main__": @@ -199,11 +192,16 @@ class AgentParams: actor = Actor(action_dim=np.prod(envs.single_action_space.shape)) actor_params = actor.init(actor_key, envs.single_observation_space.sample()) + print(actor.tabulate(jax.random.PRNGKey(0), envs.single_observation_space.sample())) actor.apply = jax.jit(actor.apply) critic = Critic() critic_params = critic.init(critic_key, envs.single_observation_space.sample()) critic.apply = jax.jit(critic.apply) - agent_optimizer = optax.adam(learning_rate=args.learning_rate, eps=1e-5) + + agent_optimizer = optax.chain( + optax.clip_by_global_norm(args.max_grad_norm), + optax.inject_hyperparams(optax.adam)(learning_rate=args.learning_rate, eps=1e-5), + ) agent_params = AgentParams( actor_params, critic_params, @@ -214,7 +212,7 @@ class AgentParams: obs = jnp.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape) actions = jnp.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape) logprobs = jnp.zeros((args.num_steps, args.num_envs)) - rewards = jnp.zeros((args.num_steps, args.num_envs)) + rewards = np.zeros((args.num_steps, args.num_envs)) dones = jnp.zeros((args.num_steps, args.num_envs)) values = jnp.zeros((args.num_steps, args.num_envs)) advantages = jnp.zeros((args.num_steps, args.num_envs)) @@ -224,6 +222,7 @@ class AgentParams: def get_action_and_value(x, obs, actions, logprobs, values, step, agent_params, key): obs = obs.at[step].set(x) # inside jit() `x = x.at[idx].set(y)` is in-place. 
action_mean, action_logstd = actor.apply(agent_params.actor_params, x) + # action_logstd = (jnp.ones_like(action_mean) * action_logstd) action_std = jnp.exp(action_logstd) key, subkey = jax.random.split(key) action = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape) @@ -246,7 +245,7 @@ def get_action_and_value2(x, action, agent_params): @jax.jit def compute_gae(next_obs, next_done, rewards, dones, values, advantages, agent_params): - advantages = advantages.at[:].set(0.0) # reset advantages + advantages = advantages.at[:].set(0.0) # reset advantages next_value = critic.apply(agent_params.critic_params, next_obs).squeeze() lastgaelam = 0 for t in reversed(range(args.num_steps)): @@ -263,32 +262,22 @@ def compute_gae(next_obs, next_done, rewards, dones, values, advantages, agent_p return jax.lax.stop_gradient(advantages), jax.lax.stop_gradient(returns) @jax.jit - def update_ppo( - obs, - logprobs, - actions, - advantages, - returns, - values, - agent_params, - agent_optimizer_state, - key - ): + def update_ppo(obs, logprobs, actions, advantages, returns, values, agent_params, agent_optimizer_state, key): b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) b_logprobs = logprobs.reshape(-1) b_actions = actions.reshape((-1,) + envs.single_action_space.shape) b_advantages = advantages.reshape(-1) b_returns = returns.reshape(-1) - b_values = values.reshape(-1) + values.reshape(-1) - def ppo_loss(agent_params, x, a, logp, adv, ret): + def ppo_loss(agent_params, x, a, logp, mb_advantages, mb_returns): newlogprob, _, newvalue = get_action_and_value2(x, a, agent_params) logratio = newlogprob - logp ratio = jnp.exp(logratio) + approx_kl = ((ratio - 1) - logratio).mean() - mb_advantages = adv - # if args.norm_adv: - # mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) # Policy loss pg_loss1 = -mb_advantages * ratio @@ -296,13 +285,14 @@ def ppo_loss(agent_params, x, a, logp, adv, ret): pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean() # Value loss - v_loss = 0.5 * ((newvalue - ret) ** 2).mean() + v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean() # entropy_loss = entropy.mean() # loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef loss = pg_loss + v_loss * args.vf_coef - return loss - ppo_loss_grad_fn = jax.value_and_grad(ppo_loss) + return loss, (pg_loss, v_loss, approx_kl) + + ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True) b_inds = jnp.arange(args.batch_size) # clipfracs = [] @@ -312,11 +302,18 @@ def ppo_loss(agent_params, x, a, logp, adv, ret): for start in range(0, args.batch_size, args.minibatch_size): end = start + args.minibatch_size mb_inds = b_inds[start:end] - loss, grads = ppo_loss_grad_fn(agent_params, b_obs[mb_inds], b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds], b_returns[mb_inds]) + (loss, (pg_loss, v_loss, approx_kl)), grads = ppo_loss_grad_fn( + agent_params, + b_obs[mb_inds], + b_actions[mb_inds], + b_logprobs[mb_inds], + b_advantages[mb_inds], + b_returns[mb_inds], + ) updates, agent_optimizer_state = agent_optimizer.update(grads, agent_optimizer_state) agent_params = optax.apply_updates(agent_params, updates) - - return loss, key, agent_params, agent_optimizer_state + + return loss, pg_loss, v_loss, approx_kl, key, agent_params, agent_optimizer_state # TRY NOT TO MODIFY: start the game global_step = 0 @@ -327,10 +324,11 @@ def ppo_loss(agent_params, 
x, a, logp, adv, ret): for update in range(1, num_updates + 1): # Annealing the rate if instructed to do so. - # if args.anneal_lr: - # frac = 1.0 - (update - 1.0) / num_updates - # lrnow = frac * args.learning_rate - # optimizer.param_groups[0]["lr"] = lrnow + if args.anneal_lr: + frac = 1.0 - (update - 1.0) / num_updates + lrnow = frac * args.learning_rate + agent_optimizer_state[1].hyperparams['learning_rate'] = lrnow + agent_optimizer.update(agent_params, agent_optimizer_state) for step in range(0, args.num_steps): global_step += 1 * args.num_envs @@ -347,11 +345,10 @@ def ppo_loss(agent_params, x, a, logp, adv, ret): writer.add_scalar("charts/avg_episodic_return", np.average(avg_returns), global_step) writer.add_scalar("charts/episodic_return", info["r"][idx], global_step) writer.add_scalar("charts/episodic_length", info["l"][idx], global_step) + rewards[step] = reward advantages, returns = compute_gae(next_obs, next_done, rewards, dones, values, advantages, agent_params) - # print(advantages.sum(), returns.sum()) - # raise - loss, key, agent_params, agent_optimizer_state = update_ppo( + loss, pg_loss, v_loss, approx_kl, key, agent_params, agent_optimizer_state = update_ppo( obs, logprobs, actions, @@ -363,16 +360,17 @@ def ppo_loss(agent_params, x, a, logp, adv, ret): key, ) - print(agent_params.actor_params["params"]["Dense_0"]["kernel"].sum(), agent_params.critic_params["params"]["Dense_0"]["kernel"].sum()) - + # print(agent_params.actor_params["params"]) + # print(agent_params.actor_params['params']['actor_logstd']) + # print(agent_params.actor_params["params"]["Dense_0"]["kernel"].sum(), agent_params.critic_params["params"]["Dense_0"]["kernel"].sum()) # # TRY NOT TO MODIFY: record rewards for plotting purposes # writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) - # writer.add_scalar("losses/value_loss", v_loss.item(), global_step) - # writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) # writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) # writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) - # writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) # writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) writer.add_scalar("losses/loss", loss.item(), global_step) print("SPS:", int(global_step / (time.time() - start_time))) From 25397eca9d4a110b7b3c74e4ad5ae31141446c36 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sun, 26 Jun 2022 20:22:32 -0400 Subject: [PATCH 15/27] remove ddpg --- cleanrl/ddpg_continuous_action_jax.py | 283 -------------------------- 1 file changed, 283 deletions(-) delete mode 100644 cleanrl/ddpg_continuous_action_jax.py diff --git a/cleanrl/ddpg_continuous_action_jax.py b/cleanrl/ddpg_continuous_action_jax.py deleted file mode 100644 index 67bad7709..000000000 --- a/cleanrl/ddpg_continuous_action_jax.py +++ /dev/null @@ -1,283 +0,0 @@ -import argparse -import os -import random -import time -from distutils.util import strtobool -from typing import Sequence - -import flax.linen as nn -import gym -import jax -import jax.numpy as jnp -import numpy as np -import optax -from stable_baselines3.common.buffers import ReplayBuffer -from torch.utils.tensorboard import SummaryWriter - - -def parse_args(): - # fmt: off - parser = 
argparse.ArgumentParser() - parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), - help="the name of this experiment") - parser.add_argument("--seed", type=int, default=1, - help="seed of the experiment") - parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="if toggled, `torch.backends.cudnn.deterministic=False`") - parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="if toggled, cuda will be enabled by default") - parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, - help="if toggled, this experiment will be tracked with Weights and Biases") - parser.add_argument("--wandb-project-name", type=str, default="cleanRL", - help="the wandb's project name") - parser.add_argument("--wandb-entity", type=str, default=None, - help="the entity (team) of wandb's project") - parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, - help="weather to capture videos of the agent performances (check out `videos` folder)") - - # Algorithm specific arguments - parser.add_argument("--env-id", type=str, default="HalfCheetah-v2", - help="the id of the environment") - parser.add_argument("--total-timesteps", type=int, default=1000000, - help="total timesteps of the experiments") - parser.add_argument("--learning-rate", type=float, default=3e-4, - help="the learning rate of the optimizer") - parser.add_argument("--buffer-size", type=int, default=int(1e6), - help="the replay memory buffer size") - parser.add_argument("--gamma", type=float, default=0.99, - help="the discount factor gamma") - parser.add_argument("--tau", type=float, default=0.005, - help="target smoothing coefficient (default: 0.005)") - parser.add_argument("--batch-size", type=int, default=256, - help="the batch size of sample from the reply memory") - parser.add_argument("--exploration-noise", type=float, default=0.1, - help="the scale of exploration noise") - parser.add_argument("--learning-starts", type=int, default=25e3, - help="timestep to start learning") - parser.add_argument("--policy-frequency", type=int, default=2, - help="the frequency of training policy (delayed)") - parser.add_argument("--noise-clip", type=float, default=0.5, - help="noise clip parameter of the Target Policy Smoothing Regularization") - args = parser.parse_args() - # fmt: on - return args - - -def make_env(env_id, seed, idx, capture_video, run_name): - def thunk(): - env = gym.make(env_id) - env = gym.wrappers.RecordEpisodeStatistics(env) - if capture_video: - if idx == 0: - env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") - env.seed(seed) - env.action_space.seed(seed) - env.observation_space.seed(seed) - return env - - return thunk - - -# ALGO LOGIC: initialize agent here: -class QNetwork(nn.Module): - @nn.compact - def __call__(self, x: jnp.ndarray, a: jnp.ndarray): - x = jnp.concatenate([x, a], -1) - x = nn.Dense(256)(x) - x = nn.relu(x) - x = nn.Dense(256)(x) - x = nn.relu(x) - x = nn.Dense(1)(x) - return x - - -class Actor(nn.Module): - action_dim: Sequence[int] - @nn.compact - def __call__(self, x): - x = nn.Dense(256)(x) - x = nn.relu(x) - x = nn.Dense(256)(x) - x = nn.relu(x) - x = nn.Dense(self.action_dim)(x) - x = nn.tanh(x) - return x - - -def update_target(src, dst, tau): - return jax.tree_map(lambda p, tp: p * tau + tp * (1 - tau), src, dst) - - -if __name__ == "__main__": - 
args = parse_args() - run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=vars(args), - name=run_name, - monitor_gym=True, - save_code=True, - ) - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - - # TRY NOT TO MODIFY: seeding - random.seed(args.seed) - np.random.seed(args.seed) - jaxRNG = jax.random.PRNGKey(args.seed) - - # env setup - envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) - assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" - - max_action = float(envs.single_action_space.high[0]) - - envs.single_observation_space.dtype = np.float32 - rb = ReplayBuffer(args.buffer_size, envs.single_observation_space, envs.single_action_space, device="cpu") - start_time = time.time() - - # TRY NOT TO MODIFY: start the game - obs = envs.reset() - - actor = Actor(action_dim=np.prod(envs.single_action_space.shape)) - actor_parameters = actor.init(jaxRNG, obs) - actor_target_parameters = actor.init(jaxRNG, obs) - actor.apply = jax.jit(actor.apply) - qf1 = QNetwork() - qf1_parameters = qf1.init(jaxRNG, obs, envs.action_space.sample()) - qf1_target_parameters = qf1.init(jaxRNG, obs, envs.action_space.sample()) - qf1.apply = jax.jit(qf1.apply) - actor_target_parameters = update_target(actor_parameters, actor_target_parameters, 1.0) - qf1_target_parameters = update_target(qf1_parameters, qf1_target_parameters, 1.0) - actor_optimizer = optax.adam(learning_rate=args.learning_rate) - actor_optimizer_state = actor_optimizer.init(actor_parameters) - qf1_optimizer = optax.adam(learning_rate=args.learning_rate) - qf1_optimizer_state = qf1_optimizer.init(qf1_parameters) - - @jax.jit - def update_critic( - observations, - actions, - next_observations, - rewards, - dones, - actor_target_parameters, - qf1_parameters, - qf1_target_parameters, - qf1_optimizer_state, - ): - next_state_actions = (actor.apply(actor_target_parameters, next_observations)).clip(-1, 1) - qf1_next_target = qf1.apply(qf1_target_parameters, next_observations, next_state_actions).reshape(-1) - next_q_value = (rewards + (1 - dones) * args.gamma * (qf1_next_target)).reshape(-1) - - def mse_loss(qf1_parameters, observations, actions, next_q_value): - return ((qf1.apply(qf1_parameters, observations, actions).squeeze() - next_q_value) ** 2).mean() - - qf1_loss_value, grads = jax.value_and_grad(mse_loss)(qf1_parameters, observations, actions, next_q_value) - updates, qf1_optimizer_state = qf1_optimizer.update(grads, qf1_optimizer_state) - qf1_parameters = optax.apply_updates(qf1_parameters, updates) - return qf1_loss_value, qf1_parameters, qf1_optimizer_state - - @jax.jit - def update_actor( - observations, - actor_parameters, - actor_target_parameters, - qf1_parameters, - qf1_target_parameters, - actor_optimizer_state, - ): - def actor_loss(actor_parameters, qf1_parameters, observations): - return -qf1.apply(qf1_parameters, observations, actor.apply(actor_parameters, observations)).mean() - - actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_parameters, qf1_parameters, observations) - updates, actor_optimizer_state = actor_optimizer.update(grads, actor_optimizer_state) - actor_parameters = 
optax.apply_updates(actor_parameters, updates) - - actor_target_parameters = update_target(actor_parameters, actor_target_parameters, args.tau) - qf1_target_parameters = update_target(qf1_parameters, qf1_target_parameters, args.tau) - return actor_loss_value, actor_parameters, actor_optimizer_state, actor_target_parameters, qf1_target_parameters - - for global_step in range(args.total_timesteps): - # ALGO LOGIC: put action logic here - if global_step < args.learning_starts: - actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) - else: - actions = actor.apply(actor_parameters, obs) - actions = np.array( - [ - ( - np.array(actions)[0] - + np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape[0]) - ).clip(envs.single_action_space.low, envs.single_action_space.high) - ] - ) - - # TRY NOT TO MODIFY: execute the game and log data. - next_obs, rewards, dones, infos = envs.step(actions) - - # TRY NOT TO MODIFY: record rewards for plotting purposes - for info in infos: - if "episode" in info.keys(): - print(f"global_step={global_step}, episodic_return={info['episode']['r']}") - writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) - writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) - break - - # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation` - real_next_obs = next_obs.copy() - for idx, d in enumerate(dones): - if d: - real_next_obs[idx] = infos[idx]["terminal_observation"] - rb.add(obs, real_next_obs, actions, rewards, dones, infos) - - # TRY NOT TO MODIFY: CRUCIAL step easy to overlook - obs = next_obs - - # ALGO LOGIC: training. - if global_step > args.learning_starts: - data = rb.sample(args.batch_size) - qf1_loss_value, qf1_parameters, qf1_optimizer_state = update_critic( - data.observations.numpy(), - data.actions.numpy(), - data.next_observations.numpy(), - data.rewards.flatten().numpy(), - data.dones.flatten().numpy(), - actor_target_parameters, - qf1_parameters, - qf1_target_parameters, - qf1_optimizer_state, - ) - if global_step % args.policy_frequency == 0: - ( - actor_loss_value, - actor_parameters, - actor_optimizer_state, - actor_target_parameters, - qf1_target_parameters, - ) = update_actor( - data.observations.numpy(), - actor_parameters, - actor_target_parameters, - qf1_parameters, - qf1_target_parameters, - actor_optimizer_state, - ) - - if global_step % 100 == 0: - writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step) - writer.add_scalar("losses/actor_loss", actor_loss_value.item(), global_step) - # writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step) - print("SPS:", int(global_step / (time.time() - start_time))) - writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) - - envs.close() - writer.close() From 1f21964311f374c2cc5c20079d457e712588e67a Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sun, 26 Jun 2022 20:30:09 -0400 Subject: [PATCH 16/27] pre-commit --- cleanrl/ppo_continuous_action_envpool_jax.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cleanrl/ppo_continuous_action_envpool_jax.py b/cleanrl/ppo_continuous_action_envpool_jax.py index 9cced2e23..85d9391ba 100644 --- a/cleanrl/ppo_continuous_action_envpool_jax.py +++ b/cleanrl/ppo_continuous_action_envpool_jax.py @@ -10,12 +10,12 @@ import envpool import flax import flax.linen as nn -from flax.linen.initializers import orthogonal, constant import gym 
import jax import jax.numpy as jnp import numpy as np import optax +from flax.linen.initializers import constant, orthogonal from torch.utils.tensorboard import SummaryWriter @@ -115,6 +115,7 @@ def step(self, action): infos, ) + class Critic(nn.Module): @nn.compact def __call__(self, x): @@ -128,6 +129,7 @@ def __call__(self, x): class Actor(nn.Module): action_dim: Sequence[int] + @nn.compact def __call__(self, x): actor_mean = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x) @@ -135,7 +137,7 @@ def __call__(self, x): actor_mean = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(actor_mean) actor_mean = nn.tanh(actor_mean) actor_mean = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(actor_mean) - actor_logstd = self.param('actor_logstd', constant(0.0), (1, self.action_dim)) + actor_logstd = self.param("actor_logstd", constant(0.0), (1, self.action_dim)) return actor_mean, actor_logstd @@ -327,7 +329,7 @@ def ppo_loss(agent_params, x, a, logp, mb_advantages, mb_returns): if args.anneal_lr: frac = 1.0 - (update - 1.0) / num_updates lrnow = frac * args.learning_rate - agent_optimizer_state[1].hyperparams['learning_rate'] = lrnow + agent_optimizer_state[1].hyperparams["learning_rate"] = lrnow agent_optimizer.update(agent_params, agent_optimizer_state) for step in range(0, args.num_steps): From 2bddebce2ab0f56b814d8c59a9ad4097f9f86134 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sun, 26 Jun 2022 21:24:27 -0400 Subject: [PATCH 17/27] stop gradient for approxkl --- cleanrl/ppo_continuous_action_envpool_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cleanrl/ppo_continuous_action_envpool_jax.py b/cleanrl/ppo_continuous_action_envpool_jax.py index 85d9391ba..4df60afd9 100644 --- a/cleanrl/ppo_continuous_action_envpool_jax.py +++ b/cleanrl/ppo_continuous_action_envpool_jax.py @@ -292,7 +292,7 @@ def ppo_loss(agent_params, x, a, logp, mb_advantages, mb_returns): # entropy_loss = entropy.mean() # loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef loss = pg_loss + v_loss * args.vf_coef - return loss, (pg_loss, v_loss, approx_kl) + return loss, (pg_loss, v_loss, jax.lax.stop_gradient(approx_kl)) ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True) From 3f46f08d41476cfbe1e20a944a029f8c2c058442 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sun, 26 Jun 2022 21:58:10 -0400 Subject: [PATCH 18/27] stupid bug: fill dones and always squeeze in MSE --- cleanrl/ppo_continuous_action_envpool_jax.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/cleanrl/ppo_continuous_action_envpool_jax.py b/cleanrl/ppo_continuous_action_envpool_jax.py index 4df60afd9..b86cd3f43 100644 --- a/cleanrl/ppo_continuous_action_envpool_jax.py +++ b/cleanrl/ppo_continuous_action_envpool_jax.py @@ -221,8 +221,9 @@ class AgentParams: avg_returns = deque(maxlen=20) @jax.jit - def get_action_and_value(x, obs, actions, logprobs, values, step, agent_params, key): + def get_action_and_value(x, d, obs, dones, actions, logprobs, values, step, agent_params, key): obs = obs.at[step].set(x) # inside jit() `x = x.at[idx].set(y)` is in-place. 
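# --- NOTE: illustrative sketch, not part of this patch ----------------------
# Patch 17 above wraps the approx_kl diagnostic in jax.lax.stop_gradient.
# stop_gradient passes a value through unchanged in the forward pass while
# contributing nothing to gradients, so the statistic can be logged without
# ever influencing the policy update. A tiny numeric check of that behaviour
# (the toy function below is made up for illustration):
import jax

def f(w, x):
    y = w * x
    return jax.lax.stop_gradient(y) * y   # only the second factor is differentiated

# d/dw [ sg(w*x) * (w*x) ] = sg(w*x) * x = (3*2) * 2 = 12, not 2*w*x*x = 24
print(jax.grad(f)(3.0, 2.0))              # -> 12.0
# ----------------------------------------------------------------------------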
+ dones = dones.at[step].set(d) action_mean, action_logstd = actor.apply(agent_params.actor_params, x) # action_logstd = (jnp.ones_like(action_mean) * action_logstd) action_std = jnp.exp(action_logstd) @@ -234,7 +235,7 @@ def get_action_and_value(x, obs, actions, logprobs, values, step, agent_params, actions = actions.at[step].set(action) logprobs = logprobs.at[step].set(logprob.sum(1)) values = values.at[step].set(value.squeeze()) - return x, obs, actions, logprobs, values, action, logprob, entropy, value, key + return obs, dones, actions, logprobs, values, action, logprob, entropy, value, key @jax.jit def get_action_and_value2(x, action, agent_params): @@ -242,7 +243,7 @@ def get_action_and_value2(x, action, agent_params): action_std = jnp.exp(action_logstd) logprob = -0.5 * ((action - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd entropy = action_logstd + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e) - value = critic.apply(agent_params.critic_params, x) + value = critic.apply(agent_params.critic_params, x).squeeze() return logprob.sum(1), entropy, value @jax.jit @@ -334,8 +335,8 @@ def ppo_loss(agent_params, x, a, logp, mb_advantages, mb_returns): for step in range(0, args.num_steps): global_step += 1 * args.num_envs - next_obs, obs, actions, logprobs, values, action, logprob, entropy, value, key = get_action_and_value( - next_obs, obs, actions, logprobs, values, step, agent_params, key + obs, dones, actions, logprobs, values, action, logprob, entropy, value, key = get_action_and_value( + next_obs, next_done, obs, dones, actions, logprobs, values, step, agent_params, key ) # TRY NOT TO MODIFY: execute the game and log data. From a0c56d3e229f1a763dcfae3e7b6ae8d9b08ed9bf Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sun, 26 Jun 2022 23:16:16 -0400 Subject: [PATCH 19/27] Speed up 70% w/ official optimizer scheulder API --- cleanrl/ppo_continuous_action_envpool_jax.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/cleanrl/ppo_continuous_action_envpool_jax.py b/cleanrl/ppo_continuous_action_envpool_jax.py index b86cd3f43..41fb3a945 100644 --- a/cleanrl/ppo_continuous_action_envpool_jax.py +++ b/cleanrl/ppo_continuous_action_envpool_jax.py @@ -79,6 +79,7 @@ def parse_args(): args = parser.parse_args() args.batch_size = int(args.num_envs * args.num_steps) args.minibatch_size = int(args.batch_size // args.num_minibatches) + args.num_updates = args.total_timesteps // args.batch_size # fmt: on return args @@ -200,9 +201,14 @@ class AgentParams: critic_params = critic.init(critic_key, envs.single_observation_space.sample()) critic.apply = jax.jit(critic.apply) + def linear_schedule(count): + # anneal learning rate linearly after one training iteration which contains + # (args.num_minibatches * args.update_epochs) gradient updates + frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates + return args.learning_rate * frac agent_optimizer = optax.chain( optax.clip_by_global_norm(args.max_grad_norm), - optax.inject_hyperparams(optax.adam)(learning_rate=args.learning_rate, eps=1e-5), + optax.inject_hyperparams(optax.adam)(learning_rate=linear_schedule, eps=1e-5), ) agent_params = AgentParams( actor_params, @@ -323,16 +329,8 @@ def ppo_loss(agent_params, x, a, logp, mb_advantages, mb_returns): start_time = time.time() next_obs = envs.reset() next_done = np.zeros(args.num_envs) - num_updates = args.total_timesteps // args.batch_size - - for update in range(1, num_updates + 1): - # Annealing the rate if 
instructed to do so. - if args.anneal_lr: - frac = 1.0 - (update - 1.0) / num_updates - lrnow = frac * args.learning_rate - agent_optimizer_state[1].hyperparams["learning_rate"] = lrnow - agent_optimizer.update(agent_params, agent_optimizer_state) + for update in range(1, args.num_updates + 1): for step in range(0, args.num_steps): global_step += 1 * args.num_envs obs, dones, actions, logprobs, values, action, logprob, entropy, value, key = get_action_and_value( @@ -368,7 +366,7 @@ def ppo_loss(agent_params, x, a, logp, mb_advantages, mb_returns): # print(agent_params.actor_params["params"]["Dense_0"]["kernel"].sum(), agent_params.critic_params["params"]["Dense_0"]["kernel"].sum()) # # TRY NOT TO MODIFY: record rewards for plotting purposes - # writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) + writer.add_scalar("charts/learning_rate", agent_optimizer_state[1].hyperparams["learning_rate"], global_step) writer.add_scalar("losses/value_loss", v_loss.item(), global_step) writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) # writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) From 84777b8135919b0e2f20c9717064070503fbae13 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 27 Jun 2022 00:00:17 -0400 Subject: [PATCH 20/27] record learning rate also --- cleanrl/ppo_continuous_action_envpool_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cleanrl/ppo_continuous_action_envpool_jax.py b/cleanrl/ppo_continuous_action_envpool_jax.py index 41fb3a945..3ca1a80a4 100644 --- a/cleanrl/ppo_continuous_action_envpool_jax.py +++ b/cleanrl/ppo_continuous_action_envpool_jax.py @@ -366,7 +366,7 @@ def ppo_loss(agent_params, x, a, logp, mb_advantages, mb_returns): # print(agent_params.actor_params["params"]["Dense_0"]["kernel"].sum(), agent_params.critic_params["params"]["Dense_0"]["kernel"].sum()) # # TRY NOT TO MODIFY: record rewards for plotting purposes - writer.add_scalar("charts/learning_rate", agent_optimizer_state[1].hyperparams["learning_rate"], global_step) + writer.add_scalar("charts/learning_rate", agent_optimizer_state[1].hyperparams["learning_rate"].item(), global_step) writer.add_scalar("losses/value_loss", v_loss.item(), global_step) writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) # writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) From 399f9a36e9d48032102b68078b1ca1cfaa940ba6 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 27 Jun 2022 00:00:38 -0400 Subject: [PATCH 21/27] remove debug code --- cleanrl/ppo_continuous_action_envpool_jax.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cleanrl/ppo_continuous_action_envpool_jax.py b/cleanrl/ppo_continuous_action_envpool_jax.py index 3ca1a80a4..577547220 100644 --- a/cleanrl/ppo_continuous_action_envpool_jax.py +++ b/cleanrl/ppo_continuous_action_envpool_jax.py @@ -361,10 +361,6 @@ def ppo_loss(agent_params, x, a, logp, mb_advantages, mb_returns): key, ) - # print(agent_params.actor_params["params"]) - # print(agent_params.actor_params['params']['actor_logstd']) - # print(agent_params.actor_params["params"]["Dense_0"]["kernel"].sum(), agent_params.critic_params["params"]["Dense_0"]["kernel"].sum()) - # # TRY NOT TO MODIFY: record rewards for plotting purposes writer.add_scalar("charts/learning_rate", agent_optimizer_state[1].hyperparams["learning_rate"].item(), global_step) writer.add_scalar("losses/value_loss", v_loss.item(), global_step) From 2d67459d2f0c892370724fd0e0260e89483cd139 Mon Sep 17 
00:00:00 2001 From: Costa Huang Date: Wed, 6 Jul 2022 23:32:54 -0400 Subject: [PATCH 22/27] minor refactor --- cleanrl/ppo_continuous_action_envpool_jax.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/cleanrl/ppo_continuous_action_envpool_jax.py b/cleanrl/ppo_continuous_action_envpool_jax.py index 577547220..128c2d78b 100644 --- a/cleanrl/ppo_continuous_action_envpool_jax.py +++ b/cleanrl/ppo_continuous_action_envpool_jax.py @@ -194,11 +194,12 @@ class AgentParams: assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" actor = Actor(action_dim=np.prod(envs.single_action_space.shape)) - actor_params = actor.init(actor_key, envs.single_observation_space.sample()) - print(actor.tabulate(jax.random.PRNGKey(0), envs.single_observation_space.sample())) - actor.apply = jax.jit(actor.apply) critic = Critic() - critic_params = critic.init(critic_key, envs.single_observation_space.sample()) + agent_params = AgentParams( + actor.init(actor_key, envs.single_observation_space.sample()), + critic.init(critic_key, envs.single_observation_space.sample()), + ) + actor.apply = jax.jit(actor.apply) critic.apply = jax.jit(critic.apply) def linear_schedule(count): @@ -210,10 +211,6 @@ def linear_schedule(count): optax.clip_by_global_norm(args.max_grad_norm), optax.inject_hyperparams(optax.adam)(learning_rate=linear_schedule, eps=1e-5), ) - agent_params = AgentParams( - actor_params, - critic_params, - ) agent_optimizer_state = agent_optimizer.init(agent_params) # ALGO Logic: Storage setup @@ -303,11 +300,10 @@ def ppo_loss(agent_params, x, a, logp, mb_advantages, mb_returns): ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True) - b_inds = jnp.arange(args.batch_size) # clipfracs = [] for _ in range(args.update_epochs): key, subkey = jax.random.split(key) - b_inds = jax.random.shuffle(subkey, b_inds) + b_inds = jax.random.permutation(subkey, args.batch_size, independent=True) for start in range(0, args.batch_size, args.minibatch_size): end = start + args.minibatch_size mb_inds = b_inds[start:end] From 20933093db07007af79f13136790077395053b9b Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 6 Jul 2022 23:34:16 -0400 Subject: [PATCH 23/27] use TrainState --- cleanrl/ppo_continuous_action_envpool_jax.py | 111 ++++++++++++------- 1 file changed, 70 insertions(+), 41 deletions(-) diff --git a/cleanrl/ppo_continuous_action_envpool_jax.py b/cleanrl/ppo_continuous_action_envpool_jax.py index 128c2d78b..40fde1565 100644 --- a/cleanrl/ppo_continuous_action_envpool_jax.py +++ b/cleanrl/ppo_continuous_action_envpool_jax.py @@ -15,6 +15,7 @@ import jax.numpy as jnp import numpy as np import optax +from flax.training.train_state import TrainState from flax.linen.initializers import constant, orthogonal from torch.utils.tensorboard import SummaryWriter @@ -193,25 +194,27 @@ class AgentParams: envs = gym.wrappers.TransformReward(envs, lambda reward: np.clip(reward, -10, 10)) assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" - actor = Actor(action_dim=np.prod(envs.single_action_space.shape)) - critic = Critic() - agent_params = AgentParams( - actor.init(actor_key, envs.single_observation_space.sample()), - critic.init(critic_key, envs.single_observation_space.sample()), - ) - actor.apply = jax.jit(actor.apply) - critic.apply = jax.jit(critic.apply) - def linear_schedule(count): # anneal learning rate linearly after one training iteration which contains # 
(args.num_minibatches * args.update_epochs) gradient updates frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates return args.learning_rate * frac - agent_optimizer = optax.chain( - optax.clip_by_global_norm(args.max_grad_norm), - optax.inject_hyperparams(optax.adam)(learning_rate=linear_schedule, eps=1e-5), + + actor = Actor(action_dim=np.prod(envs.single_action_space.shape)) + critic = Critic() + agent_state = TrainState.create( + apply_fn=None, + params=AgentParams( + actor.init(actor_key, envs.single_observation_space.sample()), + critic.init(critic_key, envs.single_observation_space.sample()), + ), + tx=optax.chain( + optax.clip_by_global_norm(args.max_grad_norm), + optax.inject_hyperparams(optax.adam)(learning_rate=linear_schedule, eps=1e-5), + ), ) - agent_optimizer_state = agent_optimizer.init(agent_params) + actor.apply = jax.jit(actor.apply) + critic.apply = jax.jit(critic.apply) # ALGO Logic: Storage setup obs = jnp.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape) @@ -224,35 +227,56 @@ def linear_schedule(count): avg_returns = deque(maxlen=20) @jax.jit - def get_action_and_value(x, d, obs, dones, actions, logprobs, values, step, agent_params, key): + def get_action_and_value( + agent_state: TrainState, + x: np.ndarray, + d: np.ndarray, + obs: np.ndarray, + dones: np.ndarray, + actions: np.ndarray, + logprobs: np.ndarray, + values: np.ndarray, + step: int, + key: jax.random.PRNGKey, + ): obs = obs.at[step].set(x) # inside jit() `x = x.at[idx].set(y)` is in-place. dones = dones.at[step].set(d) - action_mean, action_logstd = actor.apply(agent_params.actor_params, x) - # action_logstd = (jnp.ones_like(action_mean) * action_logstd) + action_mean, action_logstd = actor.apply(agent_state.params.actor_params, x) action_std = jnp.exp(action_logstd) key, subkey = jax.random.split(key) action = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape) logprob = -0.5 * ((action - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd - entropy = action_logstd + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e) - value = critic.apply(agent_params.critic_params, x) + value = critic.apply(agent_state.params.critic_params, x) actions = actions.at[step].set(action) logprobs = logprobs.at[step].set(logprob.sum(1)) values = values.at[step].set(value.squeeze()) - return obs, dones, actions, logprobs, values, action, logprob, entropy, value, key + return obs, dones, actions, logprobs, values, action, key @jax.jit - def get_action_and_value2(x, action, agent_params): - action_mean, action_logstd = actor.apply(agent_params.actor_params, x) + def get_action_and_value2( + params: flax.core.FrozenDict, + x: np.ndarray, + action: np.ndarray, + ): + action_mean, action_logstd = actor.apply(params.actor_params, x) action_std = jnp.exp(action_logstd) logprob = -0.5 * ((action - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd entropy = action_logstd + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e) - value = critic.apply(agent_params.critic_params, x).squeeze() + value = critic.apply(params.critic_params, x).squeeze() return logprob.sum(1), entropy, value @jax.jit - def compute_gae(next_obs, next_done, rewards, dones, values, advantages, agent_params): + def compute_gae( + agent_state: TrainState, + next_obs: np.ndarray, + next_done: np.ndarray, + rewards: np.ndarray, + dones: np.ndarray, + values: np.ndarray, + advantages: np.ndarray, + ): advantages = advantages.at[:].set(0.0) # reset advantages - 
next_value = critic.apply(agent_params.critic_params, next_obs).squeeze() + next_value = critic.apply(agent_state.params.critic_params, next_obs).squeeze() lastgaelam = 0 for t in reversed(range(args.num_steps)): if t == args.num_steps - 1: @@ -268,7 +292,16 @@ def compute_gae(next_obs, next_done, rewards, dones, values, advantages, agent_p return jax.lax.stop_gradient(advantages), jax.lax.stop_gradient(returns) @jax.jit - def update_ppo(obs, logprobs, actions, advantages, returns, values, agent_params, agent_optimizer_state, key): + def update_ppo( + agent_state: TrainState, + obs: np.ndarray, + logprobs: np.ndarray, + actions: np.ndarray, + advantages: np.ndarray, + returns: np.ndarray, + values: np.ndarray, + key: jax.random.PRNGKey, + ): b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) b_logprobs = logprobs.reshape(-1) b_actions = actions.reshape((-1,) + envs.single_action_space.shape) @@ -276,8 +309,8 @@ def update_ppo(obs, logprobs, actions, advantages, returns, values, agent_params b_returns = returns.reshape(-1) values.reshape(-1) - def ppo_loss(agent_params, x, a, logp, mb_advantages, mb_returns): - newlogprob, _, newvalue = get_action_and_value2(x, a, agent_params) + def ppo_loss(params, x, a, logp, mb_advantages, mb_returns): + newlogprob, _, newvalue = get_action_and_value2(params, x, a) logratio = newlogprob - logp ratio = jnp.exp(logratio) approx_kl = ((ratio - 1) - logratio).mean() @@ -308,17 +341,15 @@ def ppo_loss(agent_params, x, a, logp, mb_advantages, mb_returns): end = start + args.minibatch_size mb_inds = b_inds[start:end] (loss, (pg_loss, v_loss, approx_kl)), grads = ppo_loss_grad_fn( - agent_params, + agent_state.params, b_obs[mb_inds], b_actions[mb_inds], b_logprobs[mb_inds], b_advantages[mb_inds], b_returns[mb_inds], ) - updates, agent_optimizer_state = agent_optimizer.update(grads, agent_optimizer_state) - agent_params = optax.apply_updates(agent_params, updates) - - return loss, pg_loss, v_loss, approx_kl, key, agent_params, agent_optimizer_state + agent_state = agent_state.apply_gradients(grads=grads) + return agent_state, loss, pg_loss, v_loss, approx_kl, key # TRY NOT TO MODIFY: start the game global_step = 0 @@ -329,12 +360,12 @@ def ppo_loss(agent_params, x, a, logp, mb_advantages, mb_returns): for update in range(1, args.num_updates + 1): for step in range(0, args.num_steps): global_step += 1 * args.num_envs - obs, dones, actions, logprobs, values, action, logprob, entropy, value, key = get_action_and_value( - next_obs, next_done, obs, dones, actions, logprobs, values, step, agent_params, key + obs, dones, actions, logprobs, values, action, key = get_action_and_value( + agent_state, next_obs, next_done, obs, dones, actions, logprobs, values, step, key ) # TRY NOT TO MODIFY: execute the game and log data. 
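# --- NOTE: illustrative sketch, not part of this patch ----------------------
# The TrainState refactor above folds the params, the optax transform, and the
# optimizer state into a single pytree, so one `apply_gradients` call replaces
# the manual optimizer.update / optax.apply_updates pair from earlier patches.
# A minimal self-contained example of that pattern (the quadratic loss and the
# names below are placeholders, not CleanRL code):
import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState

state = TrainState.create(
    apply_fn=None,                        # not needed when calling modules directly
    params={"w": jnp.array(5.0)},
    tx=optax.adam(learning_rate=1e-1),
)

def loss_fn(params):
    return (params["w"] - 2.0) ** 2

for _ in range(3):
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)   # updates params and opt_state together
# state.step == 3 and state.params["w"] has moved toward 2.0
# ----------------------------------------------------------------------------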
- next_obs, reward, next_done, info = envs.step(np.array(action)) + next_obs, rewards[step], next_done, info = envs.step(np.array(action)) for idx, d in enumerate(next_done): if d: print(f"global_step={global_step}, episodic_return={info['r'][idx]}") @@ -342,23 +373,21 @@ def ppo_loss(agent_params, x, a, logp, mb_advantages, mb_returns): writer.add_scalar("charts/avg_episodic_return", np.average(avg_returns), global_step) writer.add_scalar("charts/episodic_return", info["r"][idx], global_step) writer.add_scalar("charts/episodic_length", info["l"][idx], global_step) - rewards[step] = reward - advantages, returns = compute_gae(next_obs, next_done, rewards, dones, values, advantages, agent_params) - loss, pg_loss, v_loss, approx_kl, key, agent_params, agent_optimizer_state = update_ppo( + advantages, returns = compute_gae(agent_state, next_obs, next_done, rewards, dones, values, advantages) + agent_state, loss, pg_loss, v_loss, approx_kl, key = update_ppo( + agent_state, obs, logprobs, actions, advantages, returns, values, - agent_params, - agent_optimizer_state, key, ) # # TRY NOT TO MODIFY: record rewards for plotting purposes - writer.add_scalar("charts/learning_rate", agent_optimizer_state[1].hyperparams["learning_rate"].item(), global_step) + writer.add_scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"].item(), global_step) writer.add_scalar("losses/value_loss", v_loss.item(), global_step) writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) # writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) From c411487fe06a2f58bf8d6590d97f6af8bb4e49d6 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 7 Jul 2022 14:56:39 -0400 Subject: [PATCH 24/27] `values` is not used --- cleanrl/ppo_continuous_action_envpool_jax.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/cleanrl/ppo_continuous_action_envpool_jax.py b/cleanrl/ppo_continuous_action_envpool_jax.py index 40fde1565..f142e1c44 100644 --- a/cleanrl/ppo_continuous_action_envpool_jax.py +++ b/cleanrl/ppo_continuous_action_envpool_jax.py @@ -299,7 +299,6 @@ def update_ppo( actions: np.ndarray, advantages: np.ndarray, returns: np.ndarray, - values: np.ndarray, key: jax.random.PRNGKey, ): b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) @@ -307,7 +306,6 @@ def update_ppo( b_actions = actions.reshape((-1,) + envs.single_action_space.shape) b_advantages = advantages.reshape(-1) b_returns = returns.reshape(-1) - values.reshape(-1) def ppo_loss(params, x, a, logp, mb_advantages, mb_returns): newlogprob, _, newvalue = get_action_and_value2(params, x, a) @@ -382,7 +380,6 @@ def ppo_loss(params, x, a, logp, mb_advantages, mb_returns): actions, advantages, returns, - values, key, ) From e27c81aa331b4be156aa5a750fbe0b538c8f1079 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 15 Jul 2022 17:00:11 -0400 Subject: [PATCH 25/27] refactor --- cleanrl/ppo_continuous_action_envpool_jax.py | 118 ++++++++++--------- 1 file changed, 63 insertions(+), 55 deletions(-) diff --git a/cleanrl/ppo_continuous_action_envpool_jax.py b/cleanrl/ppo_continuous_action_envpool_jax.py index f142e1c44..eb300e074 100644 --- a/cleanrl/ppo_continuous_action_envpool_jax.py +++ b/cleanrl/ppo_continuous_action_envpool_jax.py @@ -149,6 +149,17 @@ class AgentParams: critic_params: flax.core.FrozenDict +@flax.struct.dataclass +class Storage: + obs: jnp.array + actions: jnp.array + logprobs: jnp.array + dones: jnp.array + values: jnp.array + advantages: jnp.array + returns: jnp.array + + if __name__ 
== "__main__": args = parse_args() run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" @@ -217,40 +228,40 @@ def linear_schedule(count): critic.apply = jax.jit(critic.apply) # ALGO Logic: Storage setup - obs = jnp.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape) - actions = jnp.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape) - logprobs = jnp.zeros((args.num_steps, args.num_envs)) - rewards = np.zeros((args.num_steps, args.num_envs)) - dones = jnp.zeros((args.num_steps, args.num_envs)) - values = jnp.zeros((args.num_steps, args.num_envs)) - advantages = jnp.zeros((args.num_steps, args.num_envs)) - avg_returns = deque(maxlen=20) + storage = Storage( + obs=jnp.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape), + actions=jnp.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape), + logprobs=jnp.zeros((args.num_steps, args.num_envs)), + dones=jnp.zeros((args.num_steps, args.num_envs)), + values=jnp.zeros((args.num_steps, args.num_envs)), + advantages=jnp.zeros((args.num_steps, args.num_envs)), + returns=jnp.zeros((args.num_steps, args.num_envs)), + ) + rewards=np.zeros((args.num_steps, args.num_envs)) @jax.jit def get_action_and_value( agent_state: TrainState, - x: np.ndarray, - d: np.ndarray, - obs: np.ndarray, - dones: np.ndarray, - actions: np.ndarray, - logprobs: np.ndarray, - values: np.ndarray, + next_obs: np.ndarray, + next_done: np.ndarray, + storage: Storage, step: int, key: jax.random.PRNGKey, ): - obs = obs.at[step].set(x) # inside jit() `x = x.at[idx].set(y)` is in-place. - dones = dones.at[step].set(d) - action_mean, action_logstd = actor.apply(agent_state.params.actor_params, x) + action_mean, action_logstd = actor.apply(agent_state.params.actor_params, next_obs) action_std = jnp.exp(action_logstd) key, subkey = jax.random.split(key) action = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape) logprob = -0.5 * ((action - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd - value = critic.apply(agent_state.params.critic_params, x) - actions = actions.at[step].set(action) - logprobs = logprobs.at[step].set(logprob.sum(1)) - values = values.at[step].set(value.squeeze()) - return obs, dones, actions, logprobs, values, action, key + value = critic.apply(agent_state.params.critic_params, next_obs) + storage = storage.replace( + obs=storage.obs.at[step].set(next_obs), + dones=storage.dones.at[step].set(next_done), + actions=storage.actions.at[step].set(action), + logprobs=storage.logprobs.at[step].set(logprob.sum(1)), + values=storage.values.at[step].set(value.squeeze()), + ) + return storage, action, key @jax.jit def get_action_and_value2( @@ -270,12 +281,12 @@ def compute_gae( agent_state: TrainState, next_obs: np.ndarray, next_done: np.ndarray, - rewards: np.ndarray, - dones: np.ndarray, - values: np.ndarray, - advantages: np.ndarray, + rewards: np.ndarray, + storage: Storage, ): - advantages = advantages.at[:].set(0.0) # reset advantages + storage = storage.replace( + advantages=storage.advantages.at[:].set(0.0) + ) next_value = critic.apply(agent_state.params.critic_params, next_obs).squeeze() lastgaelam = 0 for t in reversed(range(args.num_steps)): @@ -283,29 +294,29 @@ def compute_gae( nextnonterminal = 1.0 - next_done nextvalues = next_value else: - nextnonterminal = 1.0 - dones[t + 1] - nextvalues = values[t + 1] - delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + nextnonterminal 
= 1.0 - storage.dones[t + 1] + nextvalues = storage.values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - storage.values[t] lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam - advantages = advantages.at[t].set(lastgaelam) - returns = advantages + values - return jax.lax.stop_gradient(advantages), jax.lax.stop_gradient(returns) + storage = storage.replace( + advantages=storage.advantages.at[t].set(lastgaelam) + ) + storage = storage.replace( + returns=storage.advantages + storage.values + ) + return storage @jax.jit def update_ppo( agent_state: TrainState, - obs: np.ndarray, - logprobs: np.ndarray, - actions: np.ndarray, - advantages: np.ndarray, - returns: np.ndarray, + storage: Storage, key: jax.random.PRNGKey, ): - b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) - b_logprobs = logprobs.reshape(-1) - b_actions = actions.reshape((-1,) + envs.single_action_space.shape) - b_advantages = advantages.reshape(-1) - b_returns = returns.reshape(-1) + b_obs = storage.obs.reshape((-1,) + envs.single_observation_space.shape) + b_logprobs = storage.logprobs.reshape(-1) + b_actions = storage.actions.reshape((-1,) + envs.single_action_space.shape) + b_advantages = storage.advantages.reshape(-1) + b_returns = storage.returns.reshape(-1) def ppo_loss(params, x, a, logp, mb_advantages, mb_returns): newlogprob, _, newvalue = get_action_and_value2(params, x, a) @@ -356,10 +367,11 @@ def ppo_loss(params, x, a, logp, mb_advantages, mb_returns): next_done = np.zeros(args.num_envs) for update in range(1, args.num_updates + 1): + update_time_start = time.time() for step in range(0, args.num_steps): global_step += 1 * args.num_envs - obs, dones, actions, logprobs, values, action, key = get_action_and_value( - agent_state, next_obs, next_done, obs, dones, actions, logprobs, values, step, key + storage, action, key = get_action_and_value( + agent_state, next_obs, next_done, storage, step, key ) # TRY NOT TO MODIFY: execute the game and log data. 
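# --- NOTE: illustrative sketch, not part of this patch ----------------------
# The Storage refactor above works because @flax.struct.dataclass produces a
# frozen dataclass registered as a JAX pytree: it can be passed into and out of
# jit'd functions, and its fields are "updated" by building a copy with
# .replace(). A minimal example of the idiom (`Rollout` and its fields are made
# up for illustration):
import jax
import jax.numpy as jnp
import flax

@flax.struct.dataclass
class Rollout:
    rewards: jnp.ndarray
    values: jnp.ndarray

@jax.jit
def write_reward(rollout, step, r):
    # returns a new Rollout; the original instance is left untouched
    return rollout.replace(rewards=rollout.rewards.at[step].set(r))

rollout = Rollout(rewards=jnp.zeros(4), values=jnp.zeros(4))
rollout = write_reward(rollout, 2, 1.5)   # rollout.rewards -> [0., 0., 1.5, 0.]
# ----------------------------------------------------------------------------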
@@ -367,23 +379,17 @@ def ppo_loss(params, x, a, logp, mb_advantages, mb_returns): for idx, d in enumerate(next_done): if d: print(f"global_step={global_step}, episodic_return={info['r'][idx]}") - avg_returns.append(info["r"][idx]) - writer.add_scalar("charts/avg_episodic_return", np.average(avg_returns), global_step) writer.add_scalar("charts/episodic_return", info["r"][idx], global_step) writer.add_scalar("charts/episodic_length", info["l"][idx], global_step) - advantages, returns = compute_gae(agent_state, next_obs, next_done, rewards, dones, values, advantages) + storage = compute_gae(agent_state, next_obs, next_done, rewards, storage) agent_state, loss, pg_loss, v_loss, approx_kl, key = update_ppo( agent_state, - obs, - logprobs, - actions, - advantages, - returns, + storage, key, ) - # # TRY NOT TO MODIFY: record rewards for plotting purposes + # # # TRY NOT TO MODIFY: record rewards for plotting purposes writer.add_scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"].item(), global_step) writer.add_scalar("losses/value_loss", v_loss.item(), global_step) writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) @@ -393,6 +399,8 @@ def ppo_loss(params, x, a, logp, mb_advantages, mb_returns): # writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) writer.add_scalar("losses/loss", loss.item(), global_step) print("SPS:", int(global_step / (time.time() - start_time))) + # print("update time:", time.time() - update_time_start) + writer.add_scalar("charts/update_time", time.time() - update_time_start, global_step) writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) envs.close() From 4b8e96b6aab278e8e9905c06bea27525f650fffb Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 15 Jul 2022 17:08:58 -0400 Subject: [PATCH 26/27] add seed --- cleanrl/ppo_continuous_action_envpool_jax.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cleanrl/ppo_continuous_action_envpool_jax.py b/cleanrl/ppo_continuous_action_envpool_jax.py index eb300e074..acc1ae926 100644 --- a/cleanrl/ppo_continuous_action_envpool_jax.py +++ b/cleanrl/ppo_continuous_action_envpool_jax.py @@ -192,6 +192,7 @@ class Storage: args.env_id, env_type="gym", num_envs=args.num_envs, + seed=args.seed, ) envs.num_envs = args.num_envs envs.single_action_space = envs.action_space From 35603719d9d2db07cc3be3c796448cbc81a1a4d7 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 12 Jan 2023 19:45:18 -0500 Subject: [PATCH 27/27] Adds two variants that uses jax.scan --- ...ontinuous_action_envpool_async_jax_scan.py | 581 ++++++++++++++++++ ..._continuous_action_envpool_xla_jax_scan.py | 478 ++++++++++++++ 2 files changed, 1059 insertions(+) create mode 100644 cleanrl/ppo_continuous_action_envpool_async_jax_scan.py create mode 100644 cleanrl/ppo_continuous_action_envpool_xla_jax_scan.py diff --git a/cleanrl/ppo_continuous_action_envpool_async_jax_scan.py b/cleanrl/ppo_continuous_action_envpool_async_jax_scan.py new file mode 100644 index 000000000..792ed56ca --- /dev/null +++ b/cleanrl/ppo_continuous_action_envpool_async_jax_scan.py @@ -0,0 +1,581 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpool_async_jax_scan_impalanet_machadopy +# https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/ +import argparse +import os +import random +import time +from distutils.util import strtobool +from typing import Sequence + +os.environ[ + "XLA_PYTHON_CLIENT_MEM_FRACTION" +] = "0.7" # see 
https://github.com/google/jax/discussions/6332#discussioncomment-1279991 + +import envpool +import flax +import flax.linen as nn +import gym +import jax +import jax.numpy as jnp +import numpy as np +import optax +from flax.linen.initializers import constant, orthogonal +from flax.training.train_state import TrainState +from torch.utils.tensorboard import SummaryWriter + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, `torch.backends.cudnn.deterministic=False`") + parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, cuda will be enabled by default") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="cleanRL", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="whether to capture videos of the agent performances (check out `videos` folder)") + parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="whether to save model into the `runs/{run_name}` folder") + parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="whether to upload the saved model to huggingface") + parser.add_argument("--hf-entity", type=str, default="", + help="the user or org name of the model repository from the Hugging Face Hub") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="Ant-v4", + help="the id of the environment") + parser.add_argument("--total-timesteps", type=int, default=20000000, + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=0.00295, + help="the learning rate of the optimizer") + parser.add_argument("--num-envs", type=int, default=128, + help="the number of parallel game environments") + parser.add_argument("--async-batch-size", type=int, default=32, + help="the envpool's batch size in the async mode") + parser.add_argument("--num-steps", type=int, default=64, + help="the number of steps to run in each environment per policy rollout") + parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggle learning rate annealing for policy and value networks") + parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Use GAE for advantage computation") + parser.add_argument("--gamma", type=float, default=0.99, + help="the discount factor gamma") + parser.add_argument("--gae-lambda", type=float, default=0.95, + help="the lambda for the general advantage estimation") + parser.add_argument("--num-minibatches", type=int, default=2, + help="the number of mini-batches") + parser.add_argument("--update-epochs", type=int, default=2, + help="the K 
epochs to update the policy") + parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles advantages normalization") + parser.add_argument("--clip-coef", type=float, default=0.2, + help="the surrogate clipping coefficient") + parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="Toggles whether or not to use a clipped loss for the value function, as per the paper.") + parser.add_argument("--ent-coef", type=float, default=0.0, + help="coefficient of the entropy") + parser.add_argument("--vf-coef", type=float, default=1.3, + help="coefficient of the value function") + parser.add_argument("--max-grad-norm", type=float, default=3.5, + help="the maximum norm for the gradient clipping") + parser.add_argument("--target-kl", type=float, default=None, + help="the target KL divergence threshold") + args = parser.parse_args() + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + args.num_updates = args.total_timesteps // args.batch_size + # fmt: on + return args + + +def make_env(env_id, seed, num_envs, async_batch_size=1): + def thunk(): + envs = envpool.make( + env_id, + env_type="gym", + num_envs=num_envs, + batch_size=async_batch_size, + seed=seed, + ) + envs = gym.wrappers.FlattenObservation(envs) # deal with dm_control's Dict observation space + envs.num_envs = num_envs + envs.single_action_space = envs.action_space + envs.single_observation_space = envs.observation_space + envs.is_vector_env = True + return envs + + return thunk + + + +class Critic(nn.Module): + @nn.compact + def __call__(self, x): + critic = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x) + critic = nn.tanh(critic) + critic = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(critic) + critic = nn.tanh(critic) + critic = nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))(critic) + return critic + + +class Actor(nn.Module): + action_dim: Sequence[int] + + @nn.compact + def __call__(self, x): + actor_mean = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x) + actor_mean = nn.tanh(actor_mean) + actor_mean = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(actor_mean) + actor_mean = nn.tanh(actor_mean) + actor_mean = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(actor_mean) + actor_logstd = self.param("actor_logstd", constant(0.0), (1, self.action_dim)) + return actor_mean, actor_logstd + + +@flax.struct.dataclass +class AgentParams: + actor_params: flax.core.FrozenDict + critic_params: flax.core.FrozenDict + + +if __name__ == "__main__": + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + key = jax.random.PRNGKey(args.seed) + key, actor_key, critic_key = jax.random.split(key, 3) + + # env setup + envs = make_env(args.env_id, args.seed, 
args.num_envs, args.async_batch_size)() + assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" + + def linear_schedule(count): + # anneal learning rate linearly after one training iteration which contains + # (args.num_minibatches * args.update_epochs) gradient updates + frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates + return args.learning_rate * frac + + actor = Actor(action_dim=np.prod(envs.single_action_space.shape)) + critic = Critic() + agent_state = TrainState.create( + apply_fn=None, + params=AgentParams( + actor.init(actor_key, envs.single_observation_space.sample()), + critic.init(critic_key, envs.single_observation_space.sample()), + ), + tx=optax.chain( + optax.clip_by_global_norm(args.max_grad_norm), + optax.inject_hyperparams(optax.adam)( + learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5 + ), + ), + ) + actor.apply = jax.jit(actor.apply) + critic.apply = jax.jit(critic.apply) + + @jax.jit + def get_action_and_value( + agent_state: TrainState, + next_obs: np.ndarray, + key: jax.random.PRNGKey, + ): + """sample action, calculate value, logprob, entropy, and update storage""" + action_mean, action_logstd = actor.apply(agent_state.params.actor_params, next_obs) + action_std = jnp.exp(action_logstd) + key, subkey = jax.random.split(key) + action = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape) + logprob = -0.5 * ((action - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd + value = critic.apply(agent_state.params.critic_params, next_obs) + return action, logprob.sum(1), value.squeeze(1), key + + @jax.jit + def get_action_and_value2( + params: flax.core.FrozenDict, + x: np.ndarray, + action: np.ndarray, + ): + """calculate value, logprob of supplied `action`, and entropy""" + action_mean, action_logstd = actor.apply(params.actor_params, x) + action_std = jnp.exp(action_logstd) + logprob = -0.5 * ((action - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd + entropy = action_logstd + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e) + value = critic.apply(params.critic_params, x).squeeze() + return logprob.sum(1), entropy, value + + def compute_gae_once(carry, x): + lastvalues, lastdones, advantages, lastgaelam, final_env_ids, final_env_id_checked = carry + ( + done, + value, + eid, + reward, + ) = x + nextnonterminal = 1.0 - lastdones[eid] + nextvalues = lastvalues[eid] + delta = jnp.where(final_env_id_checked[eid] == -1, 0, reward + args.gamma * nextvalues * nextnonterminal - value) + advantages = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam[eid] + final_env_ids = jnp.where(final_env_id_checked[eid] == 1, 1, 0) + final_env_id_checked = final_env_id_checked.at[eid].set( + jnp.where(final_env_id_checked[eid] == -1, 1, final_env_id_checked[eid]) + ) + + # the last_ variables keeps track of the actual `num_steps` + lastgaelam = lastgaelam.at[eid].set(advantages) + lastdones = lastdones.at[eid].set(done) + lastvalues = lastvalues.at[eid].set(value) + return (lastvalues, lastdones, advantages, lastgaelam, final_env_ids, final_env_id_checked), ( + advantages, + final_env_ids, + ) + + @jax.jit + def compute_gae( + env_ids: np.ndarray, + rewards: np.ndarray, + values: np.ndarray, + dones: np.ndarray, + ): + dones = jnp.asarray(dones) + values = jnp.asarray(values) + env_ids = jnp.asarray(env_ids) + rewards = jnp.asarray(rewards) + + _, B = env_ids.shape + 
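        # Each row of `env_ids`, `rewards`, `values`, and `dones` holds one envs.recv()
        # batch, so B equals args.async_batch_size; the reverse scan below walks these
        # batches backwards in time while `lastvalues` / `lastdones` / `lastgaelam`
        # carry the running per-environment GAE state, indexed by env id.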
final_env_id_checked = jnp.zeros(args.num_envs, jnp.int32) - 1 + final_env_ids = jnp.zeros(B, jnp.int32) + advantages = jnp.zeros(B) + lastgaelam = jnp.zeros(args.num_envs) + lastdones = jnp.zeros(args.num_envs) + 1 + lastvalues = jnp.zeros(args.num_envs) + + (_, _, _, _, final_env_ids, final_env_id_checked), (advantages, final_env_ids) = jax.lax.scan( + compute_gae_once, + ( + lastvalues, + lastdones, + advantages, + lastgaelam, + final_env_ids, + final_env_id_checked, + ), + ( + dones, + values, + env_ids, + rewards, + ), + reverse=True, + ) + return advantages, advantages + values, final_env_id_checked, final_env_ids + + def ppo_loss(params, x, a, logp, mb_advantages, mb_returns): + newlogprob, entropy, newvalue = get_action_and_value2(params, x, a) + logratio = newlogprob - logp + ratio = jnp.exp(logratio) + approx_kl = ((ratio - 1) - logratio).mean() + + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean() + + # Value loss + v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl)) + + ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True) + + @jax.jit + def update_ppo( + agent_state: TrainState, + obs: list, + dones: list, + values: list, + actions: list, + logprobs: list, + env_ids: list, + rewards: list, + key: jax.random.PRNGKey, + ): + obs = jnp.asarray(obs) + dones = jnp.asarray(dones) + values = jnp.asarray(values) + actions = jnp.asarray(actions) + logprobs = jnp.asarray(logprobs) + env_ids = jnp.asarray(env_ids) + rewards = jnp.asarray(rewards) + + # TODO: in an unlikely event, one of the envs might have not stepped at all, which may results in unexpected behavior + T, B = env_ids.shape + index_ranges = jnp.arange(T * B, dtype=jnp.int32) + next_index_ranges = jnp.zeros_like(index_ranges, dtype=jnp.int32) + last_env_ids = jnp.zeros(args.num_envs, dtype=jnp.int32) - 1 + + def f(carry, x): + last_env_ids, next_index_ranges = carry + env_id, index_range = x + next_index_ranges = next_index_ranges.at[last_env_ids[env_id]].set( + jnp.where(last_env_ids[env_id] != -1, index_range, next_index_ranges[last_env_ids[env_id]]) + ) + last_env_ids = last_env_ids.at[env_id].set(index_range) + return (last_env_ids, next_index_ranges), None + + (last_env_ids, next_index_ranges), _ = jax.lax.scan( + f, + (last_env_ids, next_index_ranges), + (env_ids.reshape(-1), index_ranges), + ) + + # rewards is off by one time step + rewards = rewards.reshape(-1)[next_index_ranges].reshape((args.num_steps) * async_update, args.async_batch_size) + advantages, returns, _, final_env_ids = compute_gae(env_ids, rewards, values, dones) + b_inds = jnp.nonzero(final_env_ids.reshape(-1), size=(args.num_steps) * async_update * args.async_batch_size)[0] + b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) + b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + b_logprobs = logprobs.reshape(-1) + b_advantages = advantages.reshape(-1) + b_returns = returns.reshape(-1) + + def update_epoch(carry, _): + agent_state, key = carry + key, subkey = jax.random.split(key) + + # taken from: https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py + def 
convert_data(x: jnp.ndarray): + x = jax.random.permutation(subkey, x) + x = jnp.reshape(x, (args.num_minibatches, -1) + x.shape[1:]) + return x + + def update_minibatch(agent_state, minibatch): + mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns = minibatch + (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn( + agent_state.params, + mb_obs, + mb_actions, + mb_logprobs, + mb_advantages, + mb_returns, + ) + agent_state = agent_state.apply_gradients(grads=grads) + return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) + + agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) = jax.lax.scan( + update_minibatch, + agent_state, + ( + convert_data(b_obs), + convert_data(b_actions), + convert_data(b_logprobs), + convert_data(b_advantages), + convert_data(b_returns), + ), + ) + return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) + + (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, _) = jax.lax.scan( + update_epoch, (agent_state, key), (), length=args.update_epochs + ) + return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, advantages, returns, b_inds, final_env_ids, key + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + async_update = int(args.num_envs / args.async_batch_size) + + # put data in the last index + episode_returns = np.zeros((args.num_envs,), dtype=np.float32) + returned_episode_returns = np.zeros((args.num_envs,), dtype=np.float32) + episode_lengths = np.zeros((args.num_envs,), dtype=np.float32) + returned_episode_lengths = np.zeros((args.num_envs,), dtype=np.float32) + envs.async_reset() + final_env_ids = np.zeros((async_update, args.async_batch_size), dtype=np.int32) + + for update in range(1, args.num_updates + 2): + update_time_start = time.time() + obs = [] + dones = [] + actions = [] + logprobs = [] + values = [] + env_ids = [] + rewards = [] + truncations = [] + terminations = [] + env_recv_time = 0 + inference_time = 0 + storage_time = 0 + env_send_time = 0 + + # NOTE: This is a major difference from the sync version: + # at the end of the rollout phase, the sync version will have the next observation + # ready for the value bootstrap, but the async version will not have it. + # for this reason we do `num_steps + 1`` to get the extra states for value bootstrapping. + # but note that the extra states are not used for the loss computation in the next iteration, + # while the sync version will use the extra state for the loss computation. + for step in range( + async_update, (args.num_steps + 1) * async_update + ): # num_steps + 1 to get the states for value bootstrapping. 
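            # Each pass through this loop receives, acts on, and stores one async batch
            # of `async_batch_size` environments returned by envs.recv(); covering all
            # `num_envs` environments therefore takes `async_update` consecutive iterations.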
+ env_recv_time_start = time.time() + next_obs, next_reward, next_done, info = envs.recv() + if type(next_obs) == dict: # support dict observations + next_obs = np.concatenate(list(next_obs.values()), -1) + env_recv_time += time.time() - env_recv_time_start + global_step += len(next_done) + env_id = info["env_id"] + + inference_time_start = time.time() + action, logprob, value, key = get_action_and_value(agent_state, next_obs, key) + inference_time += time.time() - inference_time_start + + env_send_time_start = time.time() + envs.send(np.array(action), env_id) + env_send_time += time.time() - env_send_time_start + storage_time_start = time.time() + obs.append(next_obs) + dones.append(next_done) + values.append(value) + actions.append(action) + logprobs.append(logprob) + env_ids.append(env_id) + rewards.append(next_reward) + truncations.append(info["TimeLimit.truncated"]) + terminations.append(next_done) + episode_returns[env_id] += next_reward + returned_episode_returns[env_id] = np.where( + next_done + info["TimeLimit.truncated"], episode_returns[env_id], returned_episode_returns[env_id] + ) + episode_returns[env_id] *= (1 - next_done) * (1 - info["TimeLimit.truncated"]) + episode_lengths[env_id] += 1 + returned_episode_lengths[env_id] = np.where( + next_done + info["TimeLimit.truncated"], episode_lengths[env_id], returned_episode_lengths[env_id] + ) + episode_lengths[env_id] *= (1 - next_done) * (1 - info["TimeLimit.truncated"]) + storage_time += time.time() - storage_time_start + + avg_episodic_return = np.mean(returned_episode_returns) + # print(returned_episode_returns) + print(f"global_step={global_step}, avg_episodic_return={avg_episodic_return}") + writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step) + writer.add_scalar("charts/avg_episodic_length", np.mean(returned_episode_lengths), global_step) + training_time_start = time.time() + ( + agent_state, + loss, + pg_loss, + v_loss, + entropy_loss, + approx_kl, + advantages, + returns, + b_inds, + final_env_ids, + key, + ) = update_ppo( + agent_state, + obs, + dones, + values, + actions, + logprobs, + env_ids, + rewards, + key, + ) + writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step) + # writer.add_scalar("stats/advantages", advantages.mean().item(), global_step) + # writer.add_scalar("stats/returns", returns.mean().item(), global_step) + writer.add_scalar("stats/truncations", np.sum(truncations), global_step) + writer.add_scalar("stats/terminations", np.sum(terminations), global_step) + + # TRY NOT TO MODIFY: record rewards for plotting purposes + writer.add_scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"].item(), global_step) + writer.add_scalar("losses/value_loss", v_loss[-1, -1].item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss[-1, -1].item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss[-1, -1].item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl[-1, -1].item(), global_step) + writer.add_scalar("losses/loss", loss[-1, -1].item(), global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + writer.add_scalar( + "charts/SPS_update", int(args.num_envs * args.num_steps / (time.time() - update_time_start)), global_step + ) + writer.add_scalar("stats/env_recv_time", env_recv_time, global_step) + writer.add_scalar("stats/inference_time", inference_time, global_step) + 
writer.add_scalar("stats/storage_time", storage_time, global_step) + writer.add_scalar("stats/env_send_time", env_send_time, global_step) + writer.add_scalar("stats/update_time", time.time() - update_time_start, global_step) + + if args.save_model: + model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model" + with open(model_path, "wb") as f: + f.write( + flax.serialization.to_bytes( + [ + vars(args), + [ + agent_state.params.network_params, + agent_state.params.actor_params, + agent_state.params.critic_params, + ], + ] + ) + ) + print(f"model saved to {model_path}") + from cleanrl_utils.evals.ppo_envpool_jax_eval import evaluate + + episodic_returns = evaluate( + model_path, + make_env, + args.env_id, + eval_episodes=10, + run_name=f"{run_name}-eval", + Model=(Network, Actor, Critic), + ) + for idx, episodic_return in enumerate(episodic_returns): + writer.add_scalar("eval/episodic_return", episodic_return, idx) + + if args.upload_model: + from cleanrl_utils.huggingface import push_to_hub + + repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval") + + envs.close() + writer.close() diff --git a/cleanrl/ppo_continuous_action_envpool_xla_jax_scan.py b/cleanrl/ppo_continuous_action_envpool_xla_jax_scan.py new file mode 100644 index 000000000..674821a89 --- /dev/null +++ b/cleanrl/ppo_continuous_action_envpool_xla_jax_scan.py @@ -0,0 +1,478 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpool_xla_jaxpy +import argparse +import os +import random +import time +from distutils.util import strtobool +from functools import partial +from typing import Sequence + +os.environ[ + "XLA_PYTHON_CLIENT_MEM_FRACTION" +] = "0.7" # see https://github.com/google/jax/discussions/6332#discussioncomment-1279991 + +import envpool +import flax +import flax.linen as nn +import gym +import jax +import jax.numpy as jnp +import numpy as np +import optax +from flax.linen.initializers import constant, orthogonal +from flax.training.train_state import TrainState +from torch.utils.tensorboard import SummaryWriter + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, `torch.backends.cudnn.deterministic=False`") + parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, cuda will be enabled by default") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="cleanRL", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="whether to capture videos of the agent performances (check out `videos` folder)") + parser.add_argument("--save-model", type=lambda 
x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="whether to save model into the `runs/{run_name}` folder") + parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="whether to upload the saved model to huggingface") + parser.add_argument("--hf-entity", type=str, default="", + help="the user or org name of the model repository from the Hugging Face Hub") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="Ant-v4", + help="the id of the environment") + parser.add_argument("--total-timesteps", type=int, default=10000000, + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=0.00295, + help="the learning rate of the optimizer") + parser.add_argument("--num-envs", type=int, default=64, + help="the number of parallel game environments") + parser.add_argument("--num-steps", type=int, default=64, + help="the number of steps to run in each environment per policy rollout") + parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggle learning rate annealing for policy and value networks") + parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Use GAE for advantage computation") + parser.add_argument("--gamma", type=float, default=0.99, + help="the discount factor gamma") + parser.add_argument("--gae-lambda", type=float, default=0.95, + help="the lambda for the general advantage estimation") + parser.add_argument("--num-minibatches", type=int, default=4, + help="the number of mini-batches") + parser.add_argument("--update-epochs", type=int, default=2, + help="the K epochs to update the policy") + parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles advantages normalization") + parser.add_argument("--clip-coef", type=float, default=0.2, + help="the surrogate clipping coefficient") + parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="Toggles whether or not to use a clipped loss for the value function, as per the paper.") + parser.add_argument("--ent-coef", type=float, default=0.0, + help="coefficient of the entropy") + parser.add_argument("--vf-coef", type=float, default=1.3, + help="coefficient of the value function") + parser.add_argument("--max-grad-norm", type=float, default=3.5, + help="the maximum norm for the gradient clipping") + parser.add_argument("--target-kl", type=float, default=None, + help="the target KL divergence threshold") + args = parser.parse_args() + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + args.num_updates = args.total_timesteps // args.batch_size + # fmt: on + return args + + + +def make_env(env_id, seed, num_envs): + def thunk(): + envs = envpool.make( + env_id, + env_type="gym", + num_envs=num_envs, + seed=seed, + ) + envs.num_envs = num_envs + envs.single_action_space = envs.action_space + envs.single_observation_space = envs.observation_space + envs.is_vector_env = True + return envs + + return thunk + + + +class Critic(nn.Module): + @nn.compact + def __call__(self, x): + critic = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x) + critic = nn.tanh(critic) + critic = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(critic) + 
critic = nn.tanh(critic) + critic = nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))(critic) + return critic + + +class Actor(nn.Module): + action_dim: Sequence[int] + + @nn.compact + def __call__(self, x): + actor_mean = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x) + actor_mean = nn.tanh(actor_mean) + actor_mean = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(actor_mean) + actor_mean = nn.tanh(actor_mean) + actor_mean = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(actor_mean) + actor_logstd = self.param("actor_logstd", constant(0.0), (1, self.action_dim)) + return actor_mean, actor_logstd + + +@flax.struct.dataclass +class AgentParams: + actor_params: flax.core.FrozenDict + critic_params: flax.core.FrozenDict + + +@flax.struct.dataclass +class Storage: + obs: jnp.array + actions: jnp.array + logprobs: jnp.array + dones: jnp.array + values: jnp.array + advantages: jnp.array + returns: jnp.array + rewards: jnp.array + + +@flax.struct.dataclass +class EpisodeStatistics: + episode_returns: jnp.array + episode_lengths: jnp.array + returned_episode_returns: jnp.array + returned_episode_lengths: jnp.array + + +if __name__ == "__main__": + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + key = jax.random.PRNGKey(args.seed) + key, actor_key, critic_key = jax.random.split(key, 3) + + # env setup + envs = make_env(args.env_id, args.seed, args.num_envs)() + episode_stats = EpisodeStatistics( + episode_returns=jnp.zeros(args.num_envs, dtype=jnp.float32), + episode_lengths=jnp.zeros(args.num_envs, dtype=jnp.int32), + returned_episode_returns=jnp.zeros(args.num_envs, dtype=jnp.float32), + returned_episode_lengths=jnp.zeros(args.num_envs, dtype=jnp.int32), + ) + handle, recv, send, step_env = envs.xla() + + def step_env_wrappeed(episode_stats, handle, action): + # print(type(action.astype(jnp.float64))) + handle, (next_obs, reward, next_done, info) = step_env(handle, action.astype(jnp.float64)) + new_episode_return = episode_stats.episode_returns + reward + new_episode_length = episode_stats.episode_lengths + 1 + episode_stats = episode_stats.replace( + episode_returns=(new_episode_return) * (1 - next_done) * (1 - info["TimeLimit.truncated"]), + episode_lengths=(new_episode_length) * (1 - next_done) * (1 - info["TimeLimit.truncated"]), + # only update the `returned_episode_returns` if the episode is done + returned_episode_returns=jnp.where( + next_done + info["TimeLimit.truncated"], new_episode_return, episode_stats.returned_episode_returns + ), + returned_episode_lengths=jnp.where( + next_done + info["TimeLimit.truncated"], new_episode_length, episode_stats.returned_episode_lengths + ), + ) + return episode_stats, handle, (next_obs, reward, next_done, info) + + assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" + + def linear_schedule(count): + # anneal learning rate linearly after one training iteration 
which contains + # (args.num_minibatches * args.update_epochs) gradient updates + frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates + return args.learning_rate * frac + + actor = Actor(action_dim=np.prod(envs.single_action_space.shape)) + critic = Critic() + agent_state = TrainState.create( + apply_fn=None, + params=AgentParams( + actor.init(actor_key, envs.single_observation_space.sample()), + critic.init(critic_key, envs.single_observation_space.sample()), + ), + tx=optax.chain( + optax.clip_by_global_norm(args.max_grad_norm), + optax.inject_hyperparams(optax.adam)( + learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5 + ), + ), + ) + actor.apply = jax.jit(actor.apply) + critic.apply = jax.jit(critic.apply) + + @jax.jit + def get_action_and_value( + agent_state: TrainState, + next_obs: np.ndarray, + key: jax.random.PRNGKey, + ): + """sample action, calculate value, logprob, entropy, and update storage""" + action_mean, action_logstd = actor.apply(agent_state.params.actor_params, next_obs) + action_std = jnp.exp(action_logstd) + key, subkey = jax.random.split(key) + action = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape) + logprob = -0.5 * ((action - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd + value = critic.apply(agent_state.params.critic_params, next_obs) + return action, logprob.sum(1), value.squeeze(1), key + + @jax.jit + def get_action_and_value2( + params: flax.core.FrozenDict, + x: np.ndarray, + action: np.ndarray, + ): + """calculate value, logprob of supplied `action`, and entropy""" + action_mean, action_logstd = actor.apply(params.actor_params, x) + action_std = jnp.exp(action_logstd) + logprob = -0.5 * ((action - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd + entropy = action_logstd + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e) + value = critic.apply(params.critic_params, x).squeeze() + return logprob.sum(1), entropy, value + + def compute_gae_once(carry, inp, gamma, gae_lambda): + advantages = carry + nextdone, nextvalues, curvalues, reward = inp + nextnonterminal = 1.0 - nextdone + + delta = reward + gamma * nextvalues * nextnonterminal - curvalues + advantages = delta + gamma * gae_lambda * nextnonterminal * advantages + return advantages, advantages + + compute_gae_once = partial(compute_gae_once, gamma=args.gamma, gae_lambda=args.gae_lambda) + + @jax.jit + def compute_gae( + agent_state: TrainState, + next_obs: np.ndarray, + next_done: np.ndarray, + storage: Storage, + ): + next_value = critic.apply(agent_state.params.critic_params, next_obs).squeeze() + + advantages = jnp.zeros((args.num_envs,)) + dones = jnp.concatenate([storage.dones, next_done[None, :]], axis=0) + values = jnp.concatenate([storage.values, next_value[None, :]], axis=0) + _, advantages = jax.lax.scan( + compute_gae_once, advantages, (dones[1:], values[1:], values[:-1], storage.rewards), reverse=True + ) + storage = storage.replace( + advantages=advantages, + returns=advantages + storage.values, + ) + return storage + + def ppo_loss(params, x, a, logp, mb_advantages, mb_returns): + newlogprob, entropy, newvalue = get_action_and_value2(params, x, a) + logratio = newlogprob - logp + ratio = jnp.exp(logratio) + approx_kl = ((ratio - 1) - logratio).mean() + + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * 
jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean() + + # Value loss + v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl)) + + ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True) + + @jax.jit + def update_ppo( + agent_state: TrainState, + storage: Storage, + key: jax.random.PRNGKey, + ): + def update_epoch(carry, unused_inp): + agent_state, key = carry + key, subkey = jax.random.split(key) + + def flatten(x): + return x.reshape((-1,) + x.shape[2:]) + + # taken from: https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py + def convert_data(x: jnp.ndarray): + x = jax.random.permutation(subkey, x) + x = jnp.reshape(x, (args.num_minibatches, -1) + x.shape[1:]) + return x + + flatten_storage = jax.tree_map(flatten, storage) + shuffled_storage = jax.tree_map(convert_data, flatten_storage) + + def update_minibatch(agent_state, minibatch): + (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn( + agent_state.params, + minibatch.obs, + minibatch.actions, + minibatch.logprobs, + minibatch.advantages, + minibatch.returns, + ) + agent_state = agent_state.apply_gradients(grads=grads) + return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) + + agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) = jax.lax.scan( + update_minibatch, agent_state, shuffled_storage + ) + return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) + + (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) = jax.lax.scan( + update_epoch, (agent_state, key), (), length=args.update_epochs + ) + return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs = envs.reset() + next_done = jnp.zeros(args.num_envs, dtype=jax.numpy.bool_) + + # based on https://github.dev/google/evojax/blob/0625d875262011d8e1b6aa32566b236f44b4da66/evojax/sim_mgr.py + def step_once(carry, step, env_step_fn): + agent_state, episode_stats, obs, done, key, handle = carry + action, logprob, value, key = get_action_and_value(agent_state, obs, key) + + episode_stats, handle, (next_obs, reward, next_done, _) = env_step_fn(episode_stats, handle, action) + storage = Storage( + obs=obs, + actions=action, + logprobs=logprob, + dones=done, + values=value, + rewards=reward, + returns=jnp.zeros_like(reward), + advantages=jnp.zeros_like(reward), + ) + return ((agent_state, episode_stats, next_obs, next_done, key, handle), storage) + + def rollout(agent_state, episode_stats, next_obs, next_done, key, handle, step_once_fn, max_steps): + (agent_state, episode_stats, next_obs, next_done, key, handle), storage = jax.lax.scan( + step_once_fn, (agent_state, episode_stats, next_obs, next_done, key, handle), (), max_steps + ) + return agent_state, episode_stats, next_obs, next_done, storage, key, handle + + rollout = partial(rollout, step_once_fn=partial(step_once, env_step_fn=step_env_wrappeed), max_steps=args.num_steps) + + for update in range(1, args.num_updates + 1): + update_time_start = time.time() + agent_state, episode_stats, next_obs, next_done, storage, key, handle = rollout( + agent_state, episode_stats, next_obs, next_done, key, handle + ) + global_step += args.num_steps * args.num_envs + 
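        # Learning phase: bootstrap the value of `next_obs`, fill `advantages`/`returns`
        # with the scan-based GAE, then run `update_epochs` epochs of minibatched
        # gradient steps inside the jitted `update_ppo`.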
storage = compute_gae(agent_state, next_obs, next_done, storage) + agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key = update_ppo( + agent_state, + storage, + key, + ) + avg_episodic_return = np.mean(jax.device_get(episode_stats.returned_episode_returns)) + print(f"global_step={global_step}, avg_episodic_return={avg_episodic_return}") + + # TRY NOT TO MODIFY: record rewards for plotting purposes + writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step) + writer.add_scalar( + "charts/avg_episodic_length", np.mean(jax.device_get(episode_stats.returned_episode_lengths)), global_step + ) + writer.add_scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"].item(), global_step) + writer.add_scalar("losses/value_loss", v_loss[-1, -1].item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss[-1, -1].item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss[-1, -1].item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl[-1, -1].item(), global_step) + writer.add_scalar("losses/loss", loss[-1, -1].item(), global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + writer.add_scalar( + "charts/SPS_update", int(args.num_envs * args.num_steps / (time.time() - update_time_start)), global_step + ) + + if args.save_model: + model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model" + with open(model_path, "wb") as f: + f.write( + flax.serialization.to_bytes( + [ + vars(args), + [ + agent_state.params.network_params, + agent_state.params.actor_params, + agent_state.params.critic_params, + ], + ] + ) + ) + print(f"model saved to {model_path}") + from cleanrl_utils.evals.ppo_envpool_jax_eval import evaluate + + episodic_returns = evaluate( + model_path, + make_env, + args.env_id, + eval_episodes=10, + run_name=f"{run_name}-eval", + Model=(Network, Actor, Critic), + ) + for idx, episodic_return in enumerate(episodic_returns): + writer.add_scalar("eval/episodic_return", episodic_return, idx) + + if args.upload_model: + from cleanrl_utils.huggingface import push_to_hub + + repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval") + + envs.close() + writer.close()
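Both new files compute GAE by folding a per-timestep update over the rollout with jax.lax.scan(..., reverse=True). The sketch below isolates that pattern in a self-contained form; gamma and gae_lambda mirror the defaults above, while the array shapes and random inputs are illustrative only, not taken from either file.

# Sketch only: reverse-scan generalized advantage estimation.
from functools import partial

import jax
import jax.numpy as jnp

GAMMA, GAE_LAMBDA = 0.99, 0.95


def compute_gae_once(carry, inp, gamma, gae_lambda):
    # carry: running GAE term for each env; inp: one timestep, walked backwards
    advantages = carry
    nextdone, nextvalue, curvalue, reward = inp
    nextnonterminal = 1.0 - nextdone
    delta = reward + gamma * nextvalue * nextnonterminal - curvalue
    advantages = delta + gamma * gae_lambda * nextnonterminal * advantages
    return advantages, advantages


compute_gae_once = partial(compute_gae_once, gamma=GAMMA, gae_lambda=GAE_LAMBDA)


@jax.jit
def compute_gae(rewards, values, dones, next_value, next_done):
    # rewards/values/dones: (num_steps, num_envs); next_value/next_done: (num_envs,)
    dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
    values = jnp.concatenate([values, next_value[None, :]], axis=0)
    _, advantages = jax.lax.scan(
        compute_gae_once,
        jnp.zeros_like(next_value),
        (dones[1:], values[1:], values[:-1], rewards),
        reverse=True,
    )
    return advantages, advantages + values[:-1]


num_steps, num_envs = 8, 4
key = jax.random.PRNGKey(0)
rewards = jax.random.normal(key, (num_steps, num_envs))
values = jnp.zeros((num_steps, num_envs))
dones = jnp.zeros((num_steps, num_envs))
adv, ret = compute_gae(rewards, values, dones, jnp.zeros(num_envs), jnp.zeros(num_envs))
print(adv.shape, ret.shape)  # (8, 4) (8, 4)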