From 2483a595be5cbc778047310d42be7ff977d99114 Mon Sep 17 00:00:00 2001 From: Stone Tao Date: Fri, 1 Mar 2024 15:10:29 -0800 Subject: [PATCH 01/17] flatten rgbd obs mode wrapper --- mani_skill2/envs/sapien_env.py | 12 ++++++----- mani_skill2/utils/common.py | 1 + mani_skill2/utils/sapien_utils.py | 17 +++++++++++++++ mani_skill2/utils/wrappers/flatten.py | 30 +++++++++++++++++++++++++++ 4 files changed, 55 insertions(+), 5 deletions(-) diff --git a/mani_skill2/envs/sapien_env.py b/mani_skill2/envs/sapien_env.py index f1912793d..7b045ef03 100644 --- a/mani_skill2/envs/sapien_env.py +++ b/mani_skill2/envs/sapien_env.py @@ -41,6 +41,7 @@ from mani_skill2.utils.sapien_utils import ( batch, get_obj_by_type, + to_cpu_tensor, to_numpy, to_tensor, unbatch, @@ -240,11 +241,10 @@ def __init__( else 0 ) obs, _ = self.reset(seed=2022, options=dict(reconfigure=True)) + self._init_raw_obs = to_cpu_tensor(obs) + """the raw observation returned by the env.reset (a cpu torch tensor/dict of tensors). Useful for future observation wrappers to use to auto generate observation spaces""" if physx.is_gpu_enabled(): obs = to_numpy(obs) - self._init_raw_obs = obs.copy() - """the raw observation returned by the env.reset. Useful for future observation wrappers to use to auto generate observation spaces""" - # TODO handle constructing single obs space from a batched result. self.action_space = self.agent.action_space self.single_action_space = self.agent.single_action_space @@ -264,9 +264,11 @@ def _update_obs_space(self, obs: Any): @cached_property def single_observation_space(self): if self.num_envs > 1: - return convert_observation_to_space(self._init_raw_obs, unbatched=True) + return convert_observation_to_space( + to_numpy(self._init_raw_obs), unbatched=True + ) else: - return convert_observation_to_space(self._init_raw_obs) + return convert_observation_to_space(to_numpy(self._init_raw_obs)) @cached_property def observation_space(self): diff --git a/mani_skill2/utils/common.py b/mani_skill2/utils/common.py index 0a068c1af..4b1c18340 100644 --- a/mani_skill2/utils/common.py +++ b/mani_skill2/utils/common.py @@ -190,6 +190,7 @@ def flatten_state_dict( Args: state_dict: a dictionary containing scalars or 1-dim vectors. + use_torch (bool): Whether to convert the data to torch tensors. Raises: AssertionError: If a value of @state_dict is an ndarray with ndim > 2. diff --git a/mani_skill2/utils/sapien_utils.py b/mani_skill2/utils/sapien_utils.py index 3a57f9a8c..9b2fe5b85 100644 --- a/mani_skill2/utils/sapien_utils.py +++ b/mani_skill2/utils/sapien_utils.py @@ -46,6 +46,23 @@ def to_tensor(array: Union[torch.Tensor, np.array, Sequence]): return torch.tensor(array) +def to_cpu_tensor(array: Union[torch.Tensor, np.array, Sequence]): + """ + Maps any given sequence to a torch tensor on the CPU. + """ + if isinstance(array, (dict)): + return {k: to_tensor(v) for k, v in array.items()} + if isinstance(array, np.ndarray): + ret = torch.from_numpy(array) + if ret.dtype == torch.float64: + ret = ret.float() + return ret + elif isinstance(array, torch.Tensor): + return array.cpu() + else: + return torch.Tensor(array).cpu() + + def _to_numpy(array: Union[Array, Sequence]) -> np.ndarray: if isinstance(array, (dict)): return {k: _to_numpy(v) for k, v in array.items()} diff --git a/mani_skill2/utils/wrappers/flatten.py b/mani_skill2/utils/wrappers/flatten.py index 757d657e5..bbbe67cf5 100644 --- a/mani_skill2/utils/wrappers/flatten.py +++ b/mani_skill2/utils/wrappers/flatten.py @@ -1,7 +1,10 @@ import copy +from typing import Dict import gymnasium as gym import gymnasium.spaces.utils +import numpy as np +import torch from gymnasium.vector.utils import batch_space from mani_skill2.envs.sapien_env import BaseEnv @@ -9,6 +12,33 @@ from mani_skill2.utils.sapien_utils import batch +class FlattenRGBDObservationWrapper(gym.ObservationWrapper): + """ + Flattens the rgbd mode observations into a dictionary with two keys, "rgbd" and "state" + """ + + def __init__(self, env) -> None: + self.base_env: BaseEnv = env.unwrapped + super().__init__(env) + new_obs = self.observation(self.base_env._init_raw_obs) + import ipdb + + ipdb.set_trace() + self.base_env._update_obs_space(new_obs) + + def observation(self, observation: Dict): + sensor_data = observation.pop("sensor_data") + del observation["sensor_param"] + images = [] + for cam_data in sensor_data.values(): + images.append(cam_data["rgb"]) + images.append(cam_data["depth"]) + images = torch.concat(images, axis=-1) + # flatten the rest of the data which should just be state data + observation = flatten_state_dict(observation, use_torch=True) + return dict(state=observation, rgbd=images) + + class FlattenObservationWrapper(gym.ObservationWrapper): """ Flattens the observations into a single vector From 9d533f3b8a84fa7e9c7e4dae1e99db24300c5d8d Mon Sep 17 00:00:00 2001 From: Stone Tao Date: Fri, 1 Mar 2024 15:10:41 -0800 Subject: [PATCH 02/17] Update flatten.py --- mani_skill2/utils/wrappers/flatten.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mani_skill2/utils/wrappers/flatten.py b/mani_skill2/utils/wrappers/flatten.py index bbbe67cf5..24c72bf4e 100644 --- a/mani_skill2/utils/wrappers/flatten.py +++ b/mani_skill2/utils/wrappers/flatten.py @@ -21,9 +21,6 @@ def __init__(self, env) -> None: self.base_env: BaseEnv = env.unwrapped super().__init__(env) new_obs = self.observation(self.base_env._init_raw_obs) - import ipdb - - ipdb.set_trace() self.base_env._update_obs_space(new_obs) def observation(self, observation: Dict): From 7feb2365d009885dffd9607f39ee5ab7933a729c Mon Sep 17 00:00:00 2001 From: Stone Tao Date: Fri, 1 Mar 2024 17:54:52 -0800 Subject: [PATCH 03/17] rgb baseline --- examples/baselines/ppo/ppo_rgb.py | 506 ++++++++++++++++++++++++++ mani_skill2/utils/wrappers/flatten.py | 6 +- 2 files changed, 510 insertions(+), 2 deletions(-) create mode 100644 examples/baselines/ppo/ppo_rgb.py diff --git a/examples/baselines/ppo/ppo_rgb.py b/examples/baselines/ppo/ppo_rgb.py new file mode 100644 index 000000000..ef8a4f072 --- /dev/null +++ b/examples/baselines/ppo/ppo_rgb.py @@ -0,0 +1,506 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy +import os +import random +import time +from dataclasses import dataclass + +import gymnasium as gym +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import tyro +from torch.distributions.normal import Normal +from torch.utils.tensorboard import SummaryWriter + +# ManiSkill specific imports +import mani_skill2.envs +from mani_skill2.utils.wrappers.flatten import FlattenActionSpaceWrapper, FlattenRGBDObservationWrapper +from mani_skill2.utils.wrappers.record import RecordEpisode +from mani_skill2.vector.wrappers.gymnasium import ManiSkillVectorEnv + +@dataclass +class Args: + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + torch_deterministic: bool = True + """if toggled, `torch.backends.cudnn.deterministic=False`""" + cuda: bool = True + """if toggled, cuda will be enabled by default""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanRL" + """the wandb's project name""" + wandb_entity: str = None + """the entity (team) of wandb's project""" + capture_video: bool = True + """whether to capture videos of the agent performances (check out `videos` folder)""" + save_model: bool = True + """whether to save model into the `runs/{run_name}` folder""" + upload_model: bool = False + """whether to upload the saved model to huggingface""" + hf_entity: str = "" + """the user or org name of the model repository from the Hugging Face Hub""" + + # Algorithm specific arguments + env_id: str = "PickCube-v1" + """the id of the environment""" + total_timesteps: int = 10000000 + """total timesteps of the experiments""" + learning_rate: float = 3e-4 + """the learning rate of the optimizer""" + num_envs: int = 512 + """the number of parallel environments""" + num_eval_envs: int = 8 + """the number of parallel evaluation environments""" + num_steps: int = 50 + """the number of steps to run in each environment per policy rollout""" + anneal_lr: bool = False + """Toggle learning rate annealing for policy and value networks""" + gamma: float = 0.8 + """the discount factor gamma""" + gae_lambda: float = 0.9 + """the lambda for the general advantage estimation""" + num_minibatches: int = 32 + """the number of mini-batches""" + update_epochs: int = 4 + """the K epochs to update the policy""" + norm_adv: bool = True + """Toggles advantages normalization""" + clip_coef: float = 0.2 + """the surrogate clipping coefficient""" + clip_vloss: bool = False + """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" + ent_coef: float = 0.0 + """coefficient of the entropy""" + vf_coef: float = 0.5 + """coefficient of the value function""" + max_grad_norm: float = 0.5 + """the maximum norm for the gradient clipping""" + target_kl: float = 0.1 + """the target KL divergence threshold""" + eval_freq: int = 25 + """evaluation frequency in terms of iterations""" + finite_horizon_gae: bool = True + + # to be filled in runtime + batch_size: int = 0 + """the batch size (computed in runtime)""" + minibatch_size: int = 0 + """the mini-batch size (computed in runtime)""" + num_iterations: int = 0 + """the number of iterations (computed in runtime)""" + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.orthogonal_(layer.weight, std) + torch.nn.init.constant_(layer.bias, bias_const) + return layer + +class DictArray(object): + def __init__(self, buffer_shape, element_space, data_dict=None, device=None): + self.buffer_shape = buffer_shape + if data_dict: + self.data = data_dict + else: + assert isinstance(element_space, gym.spaces.dict.Dict) + self.data = {} + for k, v in element_space.items(): + if isinstance(v, gym.spaces.dict.Dict): + self.data[k] = DictArray(buffer_shape, v) + else: + self.data[k] = torch.zeros(buffer_shape + v.shape).to(device) + + def keys(self): + return self.data.keys() + + def __getitem__(self, index): + if isinstance(index, str): + return self.data[index] + return { + k: v[index] for k, v in self.data.items() + } + + def __setitem__(self, index, value): + if isinstance(index, str): + self.data[index] = value + for k, v in value.items(): + self.data[k][index] = v + + @property + def shape(self): + return self.buffer_shape + + def reshape(self, shape): + t = len(self.buffer_shape) + new_dict = {} + for k,v in self.data.items(): + if isinstance(v, DictArray): + new_dict[k] = v.reshape(shape) + else: + new_dict[k] = v.reshape(shape + v.shape[t:]) + new_buffer_shape = next(iter(new_dict.values())).shape[:len(shape)] + return DictArray(new_buffer_shape, None, data_dict=new_dict) + +class NatureCNN(nn.Module): + def __init__(self, sample_obs): + super().__init__() + + extractors = {} + + self.out_features = 0 + feature_size = 256 + in_channels=sample_obs["rgbd"].shape[-1] + image_size=(sample_obs["rgbd"].shape[1], sample_obs["rgbd"].shape[2]) + state_size=sample_obs["state"].shape[-1] + + # here we use a NatureCNN architecture to process images, but any architecture is permissble here + cnn = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=32, + kernel_size=8, + stride=4, + padding=0, + ), + nn.ReLU(), + nn.Conv2d( + in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=0 + ), + nn.ReLU(), + nn.Conv2d( + in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0 + ), + nn.ReLU(), + nn.Flatten(), + ) + + # to easily figure out the dimensions after flattening, we pass a test tensor + with torch.no_grad(): + n_flatten = cnn(sample_obs["rgbd"].float().permute(0,3,1,2).cpu()).shape[1] + fc = nn.Sequential(nn.Linear(n_flatten, feature_size), nn.ReLU()) + extractors["rgbd"] = nn.Sequential(cnn, fc) + self.out_features += feature_size + + # for state data we simply pass it through a single linear layer + extractors["state"] = nn.Linear(state_size, 64) + self.out_features += 64 + + self.extractors = nn.ModuleDict(extractors) + + def forward(self, observations) -> torch.Tensor: + encoded_tensor_list = [] + # self.extractors contain nn.Modules that do all the processing. + for key, extractor in self.extractors.items(): + obs = observations[key] + if key == "rgbd": + obs = obs.float().permute(0,3,1,2) + obs = obs / 255 + encoded_tensor_list.append(extractor(obs)) + return torch.cat(encoded_tensor_list, dim=1) + +class Agent(nn.Module): + def __init__(self, envs, sample_obs): + super().__init__() + self.feature_net = NatureCNN(sample_obs=sample_obs) + # latent_size = np.array(envs.unwrapped.single_observation_space.shape).prod() + latent_size = self.feature_net.out_features + self.critic = nn.Sequential( + layer_init(nn.Linear(latent_size, 512)), + nn.Tanh(), + layer_init(nn.Linear(512, 1)), + ) + self.actor_mean = nn.Sequential( + layer_init(nn.Linear(latent_size, 512)), + nn.Tanh(), + layer_init(nn.Linear(512, np.prod(envs.unwrapped.single_action_space.shape)), std=0.01*np.sqrt(2)), + ) + self.actor_logstd = nn.Parameter(torch.ones(1, np.prod(envs.unwrapped.single_action_space.shape)) * -0.5) + def get_features(self, x): + return self.feature_net(x) + def get_value(self, x): + x = self.feature_net(x) + return self.critic(x) + def get_action(self, x, deterministic=False): + x = self.feature_net(x) + action_mean = self.actor_mean(x) + if deterministic: + return action_mean + action_logstd = self.actor_logstd.expand_as(action_mean) + action_std = torch.exp(action_logstd) + probs = Normal(action_mean, action_std) + return probs.sample() + def get_action_and_value(self, x, action=None): + x = self.feature_net(x) + action_mean = self.actor_mean(x) + action_logstd = self.actor_logstd.expand_as(action_mean) + action_std = torch.exp(action_logstd) + probs = Normal(action_mean, action_std) + if action is None: + action = probs.sample() + return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x) + + +if __name__ == "__main__": + args = tyro.cli(Args) + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + args.num_iterations = args.total_timesteps // args.batch_size + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + + # env setup + env_kwargs = dict(obs_mode="rgbd", control_mode="pd_joint_delta_pos", render_mode="rgb_array") + envs = gym.make(args.env_id, num_envs=args.num_envs, **env_kwargs) + eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, **env_kwargs) + + # rgbd obs mode returns a dict of data, we flatten it so there is just a rgbd key and state key + envs = FlattenRGBDObservationWrapper(envs, rgb_only=True) + eval_envs = FlattenRGBDObservationWrapper(eval_envs, rgb_only=True) + if isinstance(envs.action_space, gym.spaces.Dict): + envs = FlattenActionSpaceWrapper(envs) + eval_envs = FlattenActionSpaceWrapper(eval_envs) + if args.capture_video: + eval_envs = RecordEpisode(eval_envs, output_dir=f"runs/{run_name}/videos", save_trajectory=False, video_fps=30) + envs = ManiSkillVectorEnv(envs, args.num_envs, ignore_terminations=False, **env_kwargs) + eval_envs = ManiSkillVectorEnv(eval_envs, args.num_eval_envs, ignore_terminations=True, **env_kwargs) + assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" + + + + # ALGO Logic: Storage setup + obs = DictArray((args.num_steps, args.num_envs), envs.single_observation_space, device=device) + actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) + logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) + rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) + dones = torch.zeros((args.num_steps, args.num_envs)).to(device) + values = torch.zeros((args.num_steps, args.num_envs)).to(device) + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs, _ = envs.reset(seed=args.seed) + eval_obs, _ = eval_envs.reset(seed=args.seed) + next_done = torch.zeros(args.num_envs, device=device) + eps_returns = torch.zeros(args.num_envs, dtype=torch.float, device=device) + eps_lens = np.zeros(args.num_envs) + place_rew = torch.zeros(args.num_envs, device=device) + print(f"####") + print(f"args.num_iterations={args.num_iterations} args.num_envs={args.num_envs} args.num_eval_envs={args.num_eval_envs}") + print(f"args.minibatch_size={args.minibatch_size} args.batch_size={args.batch_size} args.update_epochs={args.update_epochs}") + print(f"####") + agent = Agent(envs, sample_obs=next_obs).to(device) + optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) + + for iteration in range(1, args.num_iterations + 1): + print(f"Epoch: {iteration}, global_step={global_step}") + final_values = torch.zeros((args.num_steps, args.num_envs), device=device) + agent.eval() + if iteration % args.eval_freq == 1: + # evaluate + print("Evaluating") + eval_done = False + while not eval_done: + with torch.no_grad(): + eval_obs, _, eval_terminations, eval_truncations, eval_infos = eval_envs.step(agent.get_action(eval_obs, deterministic=True)) + if eval_truncations.any(): + eval_done = True + info = eval_infos["final_info"] + episodic_return = info['episode']['r'].mean().cpu().numpy() + print(f"eval_episodic_return={episodic_return}") + writer.add_scalar("charts/eval_success_rate", info["success"].float().mean().cpu().numpy(), global_step) + writer.add_scalar("charts/eval_episodic_return", episodic_return, global_step) + writer.add_scalar("charts/eval_episodic_length", info["elapsed_steps"].float().mean().cpu().numpy(), global_step) + + if args.save_model and iteration % args.eval_freq == 1: + model_path = f"runs/{run_name}/{args.exp_name}_{iteration}.cleanrl_model" + torch.save(agent.state_dict(), model_path) + print(f"model saved to {model_path}") + # Annealing the rate if instructed to do so. + if args.anneal_lr: + frac = 1.0 - (iteration - 1.0) / args.num_iterations + lrnow = frac * args.learning_rate + optimizer.param_groups[0]["lr"] = lrnow + + for step in range(0, args.num_steps): + global_step += args.num_envs + obs[step] = next_obs + dones[step] = next_done + + # ALGO LOGIC: action logic + with torch.no_grad(): + action, logprob, _, value = agent.get_action_and_value(next_obs) + values[step] = value.flatten() + actions[step] = action + logprobs[step] = logprob + + # TRY NOT TO MODIFY: execute the game and log data. + next_obs, reward, terminations, truncations, infos = envs.step(action) + next_done = torch.logical_or(terminations, truncations).to(torch.float32) + rewards[step] = reward.view(-1) + + if "final_info" in infos: + info = infos["final_info"] + done_mask = info["_final_info"] + episodic_return = info['episode']['r'][done_mask].mean().cpu().numpy() + writer.add_scalar("charts/success_rate", info["success"][done_mask].float().mean().cpu().numpy(), global_step) + writer.add_scalar("charts/episodic_return", episodic_return, global_step) + writer.add_scalar("charts/episodic_length", info["elapsed_steps"][done_mask].float().mean().cpu().numpy(), global_step) + for k in info["final_observation"]: + info["final_observation"][k] = info["final_observation"][k][done_mask] + final_values[step, torch.arange(args.num_envs, device=device)[done_mask]] = agent.get_value(info["final_observation"]).view(-1) + + # bootstrap value according to termination and truncation + with torch.no_grad(): + next_value = agent.get_value(next_obs).reshape(1, -1) + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + next_not_done = 1.0 - next_done + nextvalues = next_value + else: + next_not_done = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + real_next_values = next_not_done * nextvalues + final_values[t] # t instead of t+1 + # next_not_done means nextvalues is computed from the correct next_obs + # if next_not_done is 1, final_values is always 0 + # if next_not_done is 0, then use final_values, which is computed according to bootstrap_at_done + if args.finite_horizon_gae: + """ + See GAE paper equation(16) line 1, we will compute the GAE based on this line only + 1 *( -V(s_t) + r_t + gamma * V(s_{t+1}) ) + lambda *( -V(s_t) + r_t + gamma * r_{t+1} + gamma^2 * V(s_{t+2}) ) + lambda^2 *( -V(s_t) + r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ... ) + lambda^3 *( -V(s_t) + r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + gamma^3 * r_{t+3} + We then normalize it by the sum of the lambda^i (instead of 1-lambda) + """ + if t == args.num_steps - 1: # initialize + lam_coef_sum = 0. + reward_term_sum = 0. # the sum of the second term + value_term_sum = 0. # the sum of the third term + lam_coef_sum = lam_coef_sum * next_not_done + reward_term_sum = reward_term_sum * next_not_done + value_term_sum = value_term_sum * next_not_done + + lam_coef_sum = 1 + args.gae_lambda * lam_coef_sum + reward_term_sum = args.gae_lambda * args.gamma * reward_term_sum + lam_coef_sum * rewards[t] + value_term_sum = args.gae_lambda * args.gamma * value_term_sum + args.gamma * real_next_values + + advantages[t] = (reward_term_sum + value_term_sum) / lam_coef_sum - values[t] + else: + delta = rewards[t] + args.gamma * real_next_values - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * next_not_done * lastgaelam # Here actually we should use next_not_terminated, but we don't have lastgamlam if terminated + returns = advantages + values + + # flatten the batch + b_obs = obs.reshape((-1,)) + b_logprobs = logprobs.reshape(-1) + b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + b_advantages = advantages.reshape(-1) + b_returns = returns.reshape(-1) + b_values = values.reshape(-1) + + # Optimizing the policy and value network + agent.train() + b_inds = np.arange(args.batch_size) + clipfracs = [] + for epoch in range(args.update_epochs): + np.random.shuffle(b_inds) + for start in range(0, args.batch_size, args.minibatch_size): + end = start + args.minibatch_size + mb_inds = b_inds[start:end] + + _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds]) + logratio = newlogprob - b_logprobs[mb_inds] + ratio = logratio.exp() + + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] + + mb_advantages = b_advantages[mb_inds] + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value loss + newvalue = newvalue.view(-1) + if args.clip_vloss: + v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 + v_clipped = b_values[mb_inds] + torch.clamp( + newvalue - b_values[mb_inds], + -args.clip_coef, + args.clip_coef, + ) + v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + else: + v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + + optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + optimizer.step() + + if args.target_kl is not None and approx_kl > args.target_kl: + break + + y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + # TRY NOT TO MODIFY: record rewards for plotting purposes + writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) + writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + writer.add_scalar("losses/explained_variance", explained_var, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + if args.save_model: + model_path = f"runs/{run_name}/{args.exp_name}_final.cleanrl_model" + torch.save(agent.state_dict(), model_path) + print(f"model saved to {model_path}") + + envs.close() + writer.close() diff --git a/mani_skill2/utils/wrappers/flatten.py b/mani_skill2/utils/wrappers/flatten.py index 24c72bf4e..46d91092a 100644 --- a/mani_skill2/utils/wrappers/flatten.py +++ b/mani_skill2/utils/wrappers/flatten.py @@ -17,9 +17,10 @@ class FlattenRGBDObservationWrapper(gym.ObservationWrapper): Flattens the rgbd mode observations into a dictionary with two keys, "rgbd" and "state" """ - def __init__(self, env) -> None: + def __init__(self, env, rgb_only=False) -> None: self.base_env: BaseEnv = env.unwrapped super().__init__(env) + self.rgb_only = rgb_only new_obs = self.observation(self.base_env._init_raw_obs) self.base_env._update_obs_space(new_obs) @@ -29,7 +30,8 @@ def observation(self, observation: Dict): images = [] for cam_data in sensor_data.values(): images.append(cam_data["rgb"]) - images.append(cam_data["depth"]) + if not self.rgb_only: + images.append(cam_data["depth"]) images = torch.concat(images, axis=-1) # flatten the rest of the data which should just be state data observation = flatten_state_dict(observation, use_torch=True) From 6eb87e03f1b86d071496c3ac5864b7c748e3c4e7 Mon Sep 17 00:00:00 2001 From: Stone Tao Date: Fri, 1 Mar 2024 22:18:14 -0800 Subject: [PATCH 04/17] bug fixes --- examples/benchmarking/benchmark_maniskill.py | 6 ------ mani_skill2/envs/tasks/pick_cube.py | 8 +++----- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/examples/benchmarking/benchmark_maniskill.py b/examples/benchmarking/benchmark_maniskill.py index 6a1e3060a..7e2d44550 100644 --- a/examples/benchmarking/benchmark_maniskill.py +++ b/examples/benchmarking/benchmark_maniskill.py @@ -1,21 +1,15 @@ # py-spy record -f speedscope -r 1000 -o profile -- python manualtest/benchmark_gpu_sim.py # python manualtest/benchmark_orbit_sim.py --task "Isaac-Lift-Cube-Franka-v0" --num_envs 512 --headless import argparse -import time import gymnasium as gym import numpy as np -import sapien import sapien.physx import sapien.render import torch import tqdm import mani_skill2.envs -from mani_skill2.envs.scenes.tasks.planner.planner import PickSubtask -from mani_skill2.envs.scenes.tasks.sequential_task import SequentialTaskEnv -from mani_skill2.utils.scene_builder.ai2thor.variants import ArchitecTHORSceneBuilder -from mani_skill2.utils.scene_builder.replicacad.scene_builder import ReplicaCADSceneBuilder from mani_skill2.vector.wrappers.gymnasium import ManiSkillVectorEnv from profiling import Profiler from mani_skill2.utils.visualization.misc import images_to_video, tile_images diff --git a/mani_skill2/envs/tasks/pick_cube.py b/mani_skill2/envs/tasks/pick_cube.py index 54b9df1aa..935e87dc1 100644 --- a/mani_skill2/envs/tasks/pick_cube.py +++ b/mani_skill2/envs/tasks/pick_cube.py @@ -56,14 +56,12 @@ def __init__(self, *args, robot_uids="panda", robot_init_qpos_noise=0.02, **kwar super().__init__(*args, robot_uids=robot_uids, **kwargs) def _register_sensors(self): - pose = look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1]) - return [ - CameraConfig("base_camera", pose.p, pose.q, 128, 128, np.pi / 2, 0.01, 10) - ] + pose = look_at(eye=[0.4, 0.4, 0.6], target=[0, 0, 0.2]) + return [CameraConfig("base_camera", pose.p, pose.q, 128, 128, 1, 0.01, 100)] def _register_human_render_cameras(self): pose = look_at([0.6, 0.7, 0.6], [0.0, 0.0, 0.35]) - return CameraConfig("render_camera", pose.p, pose.q, 512, 512, 1, 0.01, 10) + return CameraConfig("render_camera", pose.p, pose.q, 512, 512, 1, 0.01, 100) def _load_actors(self): self.table_scene = TableSceneBuilder( From 40ea163da7b7823e822055ab3f2a034e9b685898 Mon Sep 17 00:00:00 2001 From: Stone Tao Date: Sat, 2 Mar 2024 13:21:15 -0800 Subject: [PATCH 05/17] work --- examples/baselines/ppo/ppo_rgb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/baselines/ppo/ppo_rgb.py b/examples/baselines/ppo/ppo_rgb.py index ef8a4f072..01527fb6e 100644 --- a/examples/baselines/ppo/ppo_rgb.py +++ b/examples/baselines/ppo/ppo_rgb.py @@ -79,7 +79,7 @@ class Args: """coefficient of the value function""" max_grad_norm: float = 0.5 """the maximum norm for the gradient clipping""" - target_kl: float = 0.1 + target_kl: float = 0.2 """the target KL divergence threshold""" eval_freq: int = 25 """evaluation frequency in terms of iterations""" @@ -276,7 +276,7 @@ def get_action_and_value(self, x, action=None): device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") # env setup - env_kwargs = dict(obs_mode="rgbd", control_mode="pd_joint_delta_pos", render_mode="rgb_array") + env_kwargs = dict(obs_mode="rgbd", control_mode="pd_joint_delta_pos", render_mode="sensors") envs = gym.make(args.env_id, num_envs=args.num_envs, **env_kwargs) eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, **env_kwargs) From d5a384ff6593b5770c220a1be564c135e34536a5 Mon Sep 17 00:00:00 2001 From: StoneT2000 Date: Wed, 13 Mar 2024 02:44:10 -0700 Subject: [PATCH 06/17] changes --- examples/baselines/ppo/ppo_rgb.py | 55 +++++++++++++++++++--------- mani_skill/envs/sapien_env.py | 6 +-- mani_skill/utils/wrappers/flatten.py | 4 +- 3 files changed, 43 insertions(+), 22 deletions(-) diff --git a/examples/baselines/ppo/ppo_rgb.py b/examples/baselines/ppo/ppo_rgb.py index 01527fb6e..2fda3644a 100644 --- a/examples/baselines/ppo/ppo_rgb.py +++ b/examples/baselines/ppo/ppo_rgb.py @@ -14,10 +14,10 @@ from torch.utils.tensorboard import SummaryWriter # ManiSkill specific imports -import mani_skill2.envs -from mani_skill2.utils.wrappers.flatten import FlattenActionSpaceWrapper, FlattenRGBDObservationWrapper -from mani_skill2.utils.wrappers.record import RecordEpisode -from mani_skill2.vector.wrappers.gymnasium import ManiSkillVectorEnv +import mani_skill.envs +from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper, FlattenRGBDObservationWrapper +from mani_skill.utils.wrappers.record import RecordEpisode +from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv @dataclass class Args: @@ -57,6 +57,8 @@ class Args: """the number of parallel evaluation environments""" num_steps: int = 50 """the number of steps to run in each environment per policy rollout""" + num_eval_steps: int = 50 + """the number of steps to run in each evaluation environment during evaluation""" anneal_lr: bool = False """Toggle learning rate annealing for policy and value networks""" gamma: float = 0.8 @@ -287,13 +289,11 @@ def get_action_and_value(self, x, action=None): envs = FlattenActionSpaceWrapper(envs) eval_envs = FlattenActionSpaceWrapper(eval_envs) if args.capture_video: - eval_envs = RecordEpisode(eval_envs, output_dir=f"runs/{run_name}/videos", save_trajectory=False, video_fps=30) + eval_envs = RecordEpisode(eval_envs, output_dir=f"runs/{run_name}/videos", save_trajectory=False, max_steps_per_video=args.num_eval_steps, video_fps=30) envs = ManiSkillVectorEnv(envs, args.num_envs, ignore_terminations=False, **env_kwargs) - eval_envs = ManiSkillVectorEnv(eval_envs, args.num_eval_envs, ignore_terminations=True, **env_kwargs) + eval_envs = ManiSkillVectorEnv(eval_envs, args.num_eval_envs, ignore_terminations=False, **env_kwargs) assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" - - # ALGO Logic: Storage setup obs = DictArray((args.num_steps, args.num_envs), envs.single_observation_space, device=device) actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) @@ -325,18 +325,37 @@ def get_action_and_value(self, x, action=None): if iteration % args.eval_freq == 1: # evaluate print("Evaluating") - eval_done = False - while not eval_done: + eval_envs.reset() + returns = [] + eps_lens = [] + successes = [] + failures = [] + for _ in range(args.num_eval_steps): with torch.no_grad(): eval_obs, _, eval_terminations, eval_truncations, eval_infos = eval_envs.step(agent.get_action(eval_obs, deterministic=True)) - if eval_truncations.any(): - eval_done = True - info = eval_infos["final_info"] - episodic_return = info['episode']['r'].mean().cpu().numpy() - print(f"eval_episodic_return={episodic_return}") - writer.add_scalar("charts/eval_success_rate", info["success"].float().mean().cpu().numpy(), global_step) - writer.add_scalar("charts/eval_episodic_return", episodic_return, global_step) - writer.add_scalar("charts/eval_episodic_length", info["elapsed_steps"].float().mean().cpu().numpy(), global_step) + if "final_info" in eval_infos: + mask = eval_infos["_final_info"] + eps_lens.append(eval_infos["final_info"]["elapsed_steps"][mask].cpu().numpy()) + returns.append(eval_infos["final_info"]["episode"]["r"][mask].cpu().numpy()) + if "success" in eval_infos: + successes.append(eval_infos["final_info"]["success"][mask].cpu().numpy()) + if "fail" in eval_infos: + failures.append(eval_infos["final_info"]["fail"][mask].cpu().numpy()) + returns = np.concatenate(returns) + eps_lens = np.concatenate(eps_lens) + print(f"Evaluated {args.num_eval_steps * args.num_envs} steps resulting in {len(eps_lens)} episodes") + if len(successes) > 0: + successes = np.concatenate(successes) + writer.add_scalar("charts/eval_success_rate", successes.mean(), global_step) + print(f"eval_success_rate={successes.mean()}") + if len(failures) > 0: + failures = np.concatenate(failures) + writer.add_scalar("charts/eval_fail_rate", failures.mean(), global_step) + print(f"eval_fail_rate={failures.mean()}") + + print(f"eval_episodic_return={returns.mean()}") + writer.add_scalar("charts/eval_episodic_return", returns.mean(), global_step) + writer.add_scalar("charts/eval_episodic_length", eps_lens.mean(), global_step) if args.save_model and iteration % args.eval_freq == 1: model_path = f"runs/{run_name}/{args.exp_name}_{iteration}.cleanrl_model" diff --git a/mani_skill/envs/sapien_env.py b/mani_skill/envs/sapien_env.py index 626f3aca9..225f512d5 100644 --- a/mani_skill/envs/sapien_env.py +++ b/mani_skill/envs/sapien_env.py @@ -253,7 +253,7 @@ def __init__( else 0 ) obs, _ = self.reset(seed=2022, options=dict(reconfigure=True)) - self._init_raw_obs = to_cpu_tensor(obs) + self._init_raw_obs = sapien_utils.to_cpu_tensor(obs) """the raw observation returned by the env.reset (a cpu torch tensor/dict of tensors). Useful for future observation wrappers to use to auto generate observation spaces""" if physx.is_gpu_enabled(): obs = sapien_utils.to_numpy(obs) @@ -281,10 +281,10 @@ def _update_obs_space(self, obs: Any): def single_observation_space(self): if self.num_envs > 1: return convert_observation_to_space( - to_numpy(self._init_raw_obs), unbatched=True + sapien_utils.to_numpy(self._init_raw_obs), unbatched=True ) else: - return convert_observation_to_space(to_numpy(self._init_raw_obs)) + return convert_observation_to_space(sapien_utils.to_numpy(self._init_raw_obs)) @cached_property def observation_space(self): diff --git a/mani_skill/utils/wrappers/flatten.py b/mani_skill/utils/wrappers/flatten.py index 499faa451..8ea009f46 100644 --- a/mani_skill/utils/wrappers/flatten.py +++ b/mani_skill/utils/wrappers/flatten.py @@ -21,7 +21,9 @@ def __init__(self, env, rgb_only=False) -> None: self.base_env: BaseEnv = env.unwrapped super().__init__(env) self.rgb_only = rgb_only - new_obs = self.observation(self.base_env._init_raw_obs) + new_obs = self.observation( + sapien_utils.to_cpu_tensor(self.base_env._init_raw_obs) + ) self.base_env._update_obs_space(new_obs) def observation(self, observation: Dict): From 7a7b31384fc39fd3342023df94a7f5db49a542f9 Mon Sep 17 00:00:00 2001 From: Stone Tao Date: Mon, 18 Mar 2024 14:35:04 -0700 Subject: [PATCH 07/17] fix target kl breaking too late --- examples/baselines/ppo/ppo.py | 3 +++ examples/baselines/ppo/ppo_rgb.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/examples/baselines/ppo/ppo.py b/examples/baselines/ppo/ppo.py index 9287cc221..f235edb2e 100644 --- a/examples/baselines/ppo/ppo.py +++ b/examples/baselines/ppo/ppo.py @@ -372,6 +372,9 @@ def clip_action(action: torch.Tensor): approx_kl = ((ratio - 1) - logratio).mean() clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] + if args.target_kl is not None and approx_kl > args.target_kl: + break + mb_advantages = b_advantages[mb_inds] if args.norm_adv: mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) diff --git a/examples/baselines/ppo/ppo_rgb.py b/examples/baselines/ppo/ppo_rgb.py index 2fda3644a..762c8a600 100644 --- a/examples/baselines/ppo/ppo_rgb.py +++ b/examples/baselines/ppo/ppo_rgb.py @@ -466,6 +466,9 @@ def get_action_and_value(self, x, action=None): approx_kl = ((ratio - 1) - logratio).mean() clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] + if args.target_kl is not None and approx_kl > args.target_kl: + break + mb_advantages = b_advantages[mb_inds] if args.norm_adv: mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) From 6e9551dbc4f6919162653c0dd1a9a21279623f84 Mon Sep 17 00:00:00 2001 From: Stone Tao Date: Mon, 18 Mar 2024 16:30:55 -0700 Subject: [PATCH 08/17] results --- examples/baselines/ppo/README.md | 13 ++++++++++++- examples/baselines/ppo/ppo_rgb.py | 2 +- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/examples/baselines/ppo/README.md b/examples/baselines/ppo/README.md index 99f78118b..1f40acaa1 100644 --- a/examples/baselines/ppo/README.md +++ b/examples/baselines/ppo/README.md @@ -2,7 +2,7 @@ Code adapted from [CleanRL](https://github.com/vwxyzjn/cleanrl/) -Below is a sample of various commands you can run to train a policy to solve various tasks with PPO that are lightly tuned already. The fastest one is the PushCube-v1 task which can take less than a minute to train on the GPU. +Below is a sample of various commands you can run to train a state-based policy to solve various tasks with PPO that are lightly tuned already. The fastest one is the PushCube-v1 task which can take less than a minute to train on the GPU and the PickCube-v1 task which can take 2-5 minutes on the GPU. ```bash python ppo.py --env_id="PushCube-v1" \ @@ -22,4 +22,15 @@ python ppo.py --env_id="TwoRobotStackCube-v1" \ --total_timesteps=40_000_000 --num-steps=100 --num-eval-steps=100 ``` +Below is a sample of various commands for training a image-based policy with PPO that are lightly tuned. The fastest again is also PushCube-v1 which can take about 5 minutes and PickCube-v1 which take 40 minutes. You will need to tune the `--num_envs` argument according to how much GPU memory you have as rendering visual observations uses a lot of memory. The settings below should all take less than 15GB of GPU memory. + +```bash +python ppo_rgb.py --env_id="PushCube-v1" \ + --num_envs=512 --update_epochs=8 --num_minibatches=16 \ + --total_timesteps=5_000_000 --eval_freq=10 --num-steps=20 +python ppo_rgb.py --env_id="PickCube-v1" \ + --num_envs=256 --update_epochs=8 --num_minibatches=16 \ + --total_timesteps=10_000_000 +``` + \ No newline at end of file diff --git a/examples/baselines/ppo/ppo_rgb.py b/examples/baselines/ppo/ppo_rgb.py index 762c8a600..ede546be9 100644 --- a/examples/baselines/ppo/ppo_rgb.py +++ b/examples/baselines/ppo/ppo_rgb.py @@ -278,7 +278,7 @@ def get_action_and_value(self, x, action=None): device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") # env setup - env_kwargs = dict(obs_mode="rgbd", control_mode="pd_joint_delta_pos", render_mode="sensors") + env_kwargs = dict(obs_mode="rgbd", control_mode="pd_joint_delta_pos", render_mode="rgb_array") envs = gym.make(args.env_id, num_envs=args.num_envs, **env_kwargs) eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, **env_kwargs) From 3532370a235a12057ed6007f9d70b4f45bc383ad Mon Sep 17 00:00:00 2001 From: StoneT2000 Date: Sat, 13 Apr 2024 00:20:27 -0700 Subject: [PATCH 09/17] fixes --- mani_skill/utils/common.py | 27 ++++++++++++++++++++++----- mani_skill/utils/wrappers/flatten.py | 4 +--- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/mani_skill/utils/common.py b/mani_skill/utils/common.py index b7572d19c..053c0b07c 100644 --- a/mani_skill/utils/common.py +++ b/mani_skill/utils/common.py @@ -137,7 +137,7 @@ def to_tensor(array: Union[torch.Tensor, np.array, Sequence], device: Device = N """ if isinstance(array, (dict)): return {k: to_tensor(v) for k, v in array.items()} - if get_backend_name() == "torch": + if physx.is_gpu_enabled(): if isinstance(array, np.ndarray): if array.dtype == np.uint16: array = array.astype(np.int32) @@ -147,12 +147,12 @@ def to_tensor(array: Union[torch.Tensor, np.array, Sequence], device: Device = N elif isinstance(array, torch.Tensor): ret = array else: - ret = torch.Tensor(array) + ret = torch.tensor(array) if device is None: return ret.cuda() else: return ret.to(device) - elif get_backend_name() == "numpy": + else: if isinstance(array, np.ndarray): if array.dtype == np.uint16: array = array.astype(np.int32) @@ -164,15 +164,32 @@ def to_tensor(array: Union[torch.Tensor, np.array, Sequence], device: Device = N if ret.dtype == torch.float64: ret = ret.float() elif np.iterable(array): - ret = torch.Tensor(array) + ret = torch.tensor(array) else: - ret = torch.Tensor(array) + ret = torch.tensor(array) if device is None: return ret else: return ret.to(device) +def to_cpu_tensor(array: Union[torch.Tensor, np.array, Sequence]): + """ + Maps any given sequence to a torch tensor on the CPU. + """ + if isinstance(array, (dict)): + return {k: to_tensor(v) for k, v in array.items()} + if isinstance(array, np.ndarray): + ret = torch.from_numpy(array) + if ret.dtype == torch.float64: + ret = ret.float() + return ret + elif isinstance(array, torch.Tensor): + return array.cpu() + else: + return torch.tensor(array).cpu() + + # TODO (stao): Clean up this code def flatten_state_dict( state_dict: dict, use_torch=False, device: Device = None diff --git a/mani_skill/utils/wrappers/flatten.py b/mani_skill/utils/wrappers/flatten.py index 5b0904871..ba982e10a 100644 --- a/mani_skill/utils/wrappers/flatten.py +++ b/mani_skill/utils/wrappers/flatten.py @@ -20,9 +20,7 @@ def __init__(self, env, rgb_only=False) -> None: self.base_env: BaseEnv = env.unwrapped super().__init__(env) self.rgb_only = rgb_only - new_obs = self.observation( - sapien_utils.to_cpu_tensor(self.base_env._init_raw_obs) - ) + new_obs = self.observation(common.to_cpu_tensor(self.base_env._init_raw_obs)) self.base_env._update_obs_space(new_obs) def observation(self, observation: Dict): From 662cd5264fe4de2c07d31b3ad25f6d3d77dd9a7a Mon Sep 17 00:00:00 2001 From: StoneT2000 Date: Sat, 13 Apr 2024 00:21:26 -0700 Subject: [PATCH 10/17] Update flatten.py --- mani_skill/utils/wrappers/flatten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mani_skill/utils/wrappers/flatten.py b/mani_skill/utils/wrappers/flatten.py index ba982e10a..8cb324836 100644 --- a/mani_skill/utils/wrappers/flatten.py +++ b/mani_skill/utils/wrappers/flatten.py @@ -33,7 +33,7 @@ def observation(self, observation: Dict): images.append(cam_data["depth"]) images = torch.concat(images, axis=-1) # flatten the rest of the data which should just be state data - observation = flatten_state_dict(observation, use_torch=True) + observation = common.flatten_state_dict(observation, use_torch=True) return dict(state=observation, rgbd=images) From 75d3a5ee556a75fa1533fb8ce06ed3fb7b551440 Mon Sep 17 00:00:00 2001 From: StoneT2000 Date: Sat, 13 Apr 2024 00:22:34 -0700 Subject: [PATCH 11/17] work --- examples/baselines/ppo/ppo_rgb.py | 12 ++++++------ mani_skill/utils/wrappers/flatten.py | 5 ++++- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/examples/baselines/ppo/ppo_rgb.py b/examples/baselines/ppo/ppo_rgb.py index ede546be9..2e8b383eb 100644 --- a/examples/baselines/ppo/ppo_rgb.py +++ b/examples/baselines/ppo/ppo_rgb.py @@ -153,8 +153,8 @@ def __init__(self, sample_obs): self.out_features = 0 feature_size = 256 - in_channels=sample_obs["rgbd"].shape[-1] - image_size=(sample_obs["rgbd"].shape[1], sample_obs["rgbd"].shape[2]) + in_channels=sample_obs["rgb"].shape[-1] + image_size=(sample_obs["rgb"].shape[1], sample_obs["rgb"].shape[2]) state_size=sample_obs["state"].shape[-1] # here we use a NatureCNN architecture to process images, but any architecture is permissble here @@ -180,9 +180,9 @@ def __init__(self, sample_obs): # to easily figure out the dimensions after flattening, we pass a test tensor with torch.no_grad(): - n_flatten = cnn(sample_obs["rgbd"].float().permute(0,3,1,2).cpu()).shape[1] + n_flatten = cnn(sample_obs["rgb"].float().permute(0,3,1,2).cpu()).shape[1] fc = nn.Sequential(nn.Linear(n_flatten, feature_size), nn.ReLU()) - extractors["rgbd"] = nn.Sequential(cnn, fc) + extractors["rgb"] = nn.Sequential(cnn, fc) self.out_features += feature_size # for state data we simply pass it through a single linear layer @@ -196,7 +196,7 @@ def forward(self, observations) -> torch.Tensor: # self.extractors contain nn.Modules that do all the processing. for key, extractor in self.extractors.items(): obs = observations[key] - if key == "rgbd": + if key == "rgb": obs = obs.float().permute(0,3,1,2) obs = obs / 255 encoded_tensor_list.append(extractor(obs)) @@ -278,7 +278,7 @@ def get_action_and_value(self, x, action=None): device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") # env setup - env_kwargs = dict(obs_mode="rgbd", control_mode="pd_joint_delta_pos", render_mode="rgb_array") + env_kwargs = dict(obs_mode="rgbd", control_mode="pd_ee_delta_pose", render_mode="rgb_array") envs = gym.make(args.env_id, num_envs=args.num_envs, **env_kwargs) eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, **env_kwargs) diff --git a/mani_skill/utils/wrappers/flatten.py b/mani_skill/utils/wrappers/flatten.py index 8cb324836..5fe902955 100644 --- a/mani_skill/utils/wrappers/flatten.py +++ b/mani_skill/utils/wrappers/flatten.py @@ -34,7 +34,10 @@ def observation(self, observation: Dict): images = torch.concat(images, axis=-1) # flatten the rest of the data which should just be state data observation = common.flatten_state_dict(observation, use_torch=True) - return dict(state=observation, rgbd=images) + if self.rgb_only: + return dict(state=observation, rgb=images) + else: + return dict(state=observation, rgbd=images) class FlattenObservationWrapper(gym.ObservationWrapper): From 9dc821c0de255dde368e3774b6ce83b7c17ea50a Mon Sep 17 00:00:00 2001 From: Stone Tao Date: Wed, 24 Apr 2024 14:24:17 -0700 Subject: [PATCH 12/17] record better metrics, bug fixes --- examples/baselines/ppo/README.md | 32 +++++++++++++++------ examples/baselines/ppo/ppo.py | 10 +++++-- examples/baselines/ppo/ppo_rgb.py | 13 +++++---- mani_skill/agents/controllers/pd_ee_pose.py | 7 +++-- mani_skill/envs/tasks/push_cube.py | 2 +- 5 files changed, 45 insertions(+), 19 deletions(-) diff --git a/examples/baselines/ppo/README.md b/examples/baselines/ppo/README.md index 20ab342d1..a6f4abf1a 100644 --- a/examples/baselines/ppo/README.md +++ b/examples/baselines/ppo/README.md @@ -1,9 +1,15 @@ # Proximal Policy Optimization (PPO) -Code for running the PPO RL algorithm is adapted from [CleanRL](https://github.com/vwxyzjn/cleanrl/). It is written to be a single-file and easy to follow/read +Code for running the PPO RL algorithm is adapted from [CleanRL](https://github.com/vwxyzjn/cleanrl/). It is written to be single-file and easy to follow/read, and supports state-based RL and visual-based RL code. + + +## State Based RL Below is a sample of various commands you can run to train a state-based policy to solve various tasks with PPO that are lightly tuned already. The fastest one is the PushCube-v1 task which can take less than a minute to train on the GPU and the PickCube-v1 task which can take 2-5 minutes on the GPU. +The PPO baseline is not guaranteed to work for tasks not tested below as some tasks do not have dense rewards yet or well tuned ones, or simply are too hard with standard PPO (or our team has not had time to verify results yet) + + ```bash python ppo.py --env_id="PushCube-v1" \ --num_envs=2048 --update_epochs=8 --num_minibatches=32 \ @@ -34,7 +40,7 @@ python ppo.py --env_id="PickSingleYCB-v1" \ --total_timesteps=25_000_000 python ppo.py --env_id="PegInsertionSide-v1" \ --num_envs=1024 --update_epochs=8 --num_minibatches=32 \ - --total_timesteps=150_000_000 --num-steps=100 --num-eval-steps=100 + --total_timesteps=250_000_000 --num-steps=100 --num-eval-steps=100 python ppo.py --env_id="TwoRobotStackCube-v1" \ --num_envs=1024 --update_epochs=8 --num_minibatches=32 \ --total_timesteps=40_000_000 --num-steps=100 --num-eval-steps=100 @@ -67,19 +73,29 @@ python ppo.py --env_id="UnitreeH1Stand-v1" \ python ppo.py --env_id="OpenCabinetDrawer-v1" \ --num_envs=1024 --update_epochs=8 --num_minibatches=32 \ - --total_timesteps=10_000_000 --num-steps=100 --num-eval-steps=100 - --gamma=0.9 - + --total_timesteps=10_000_000 --num-steps=100 --num-eval-steps=100 ``` +## Visual Based RL + +Below is a sample of various commands for training a image-based policy with PPO that are lightly tuned. The fastest again is also PushCube-v1 which can take about 1-5 minutes and PickCube-v1 which takes 30-60 minutes. You will need to tune the `--num_envs` argument according to how much GPU memory you have as rendering visual observations uses a lot of memory. The settings below should all take less than 15GB of GPU memory. Note that while if you have enough memory you can easily increase the number of environments, this does not necessarily mean wall-time or sample efficiency improve. + +The visual PPO baseline is not guaranteed to work for tasks not tested below as some tasks do not have dense rewards yet or well tuned ones, or simply are too hard with standard PPO (or our team has not had time to verify results yet) + -Below is a sample of various commands for training a image-based policy with PPO that are lightly tuned. The fastest again is also PushCube-v1 which can take about 5 minutes and PickCube-v1 which take 40 minutes. You will need to tune the `--num_envs` argument according to how much GPU memory you have as rendering visual observations uses a lot of memory. The settings below should all take less than 15GB of GPU memory. ```bash python ppo_rgb.py --env_id="PushCube-v1" \ --num_envs=512 --update_epochs=8 --num_minibatches=16 \ - --total_timesteps=5_000_000 --eval_freq=10 --num-steps=20 + --total_timesteps=1_000_000 --eval_freq=10 --num-steps=20 python ppo_rgb.py --env_id="PickCube-v1" \ --num_envs=256 --update_epochs=8 --num_minibatches=16 \ --total_timesteps=10_000_000 -``` \ No newline at end of file +python ppo_rgb.py --env_id="OpenCabinetDrawer-v1" \ + --num_envs=256 --update_epochs=8 --num_minibatches=32 \ + --total_timesteps=10_000_000 --num-steps=100 --num-eval-steps=100 +``` + +## Some Notes + +- The code currently does not have the best way to evaluate the agents in that during GPU simulation, all assets are frozen per parallel environment (changing them slows training down). Thus when doing evaluation, even though we evaluate on multiple (8 is default) environments at once, they will always feature the same set of geometry. This only affects tasks where there is geometry variation (e.g. PickClutterYCB, OpenCabinetDrawer). You can make it more accurate by increasing the number of evaluation environments. Our team is discussing still what is the best way to evaluate trained agents properly without hindering performance. \ No newline at end of file diff --git a/examples/baselines/ppo/ppo.py b/examples/baselines/ppo/ppo.py index 574fd9ada..529e0b20c 100644 --- a/examples/baselines/ppo/ppo.py +++ b/examples/baselines/ppo/ppo.py @@ -292,6 +292,7 @@ def clip_action(action: torch.Tensor): lrnow = frac * args.learning_rate optimizer.param_groups[0]["lr"] = lrnow + rollout_time = time.time() for step in range(0, args.num_steps): global_step += args.num_envs obs[step] = next_obs @@ -321,7 +322,7 @@ def clip_action(action: torch.Tensor): writer.add_scalar("charts/episodic_length", final_info["elapsed_steps"][done_mask].cpu().numpy().mean(), global_step) final_values[step, torch.arange(args.num_envs, device=device)[done_mask]] = agent.get_value(final_info["final_observation"][done_mask]).view(-1) - + rollout_time = time.time() - rollout_time # bootstrap value according to termination and truncation with torch.no_grad(): next_value = agent.get_value(next_obs).reshape(1, -1) @@ -377,6 +378,7 @@ def clip_action(action: torch.Tensor): agent.train() b_inds = np.arange(args.batch_size) clipfracs = [] + update_time = time.time() for epoch in range(args.update_epochs): np.random.shuffle(b_inds) for start in range(0, args.batch_size, args.minibatch_size): @@ -430,12 +432,12 @@ def clip_action(action: torch.Tensor): if args.target_kl is not None and approx_kl > args.target_kl: break + update_time = time.time() - update_time y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() var_y = np.var(y_true) explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y - # TRY NOT TO MODIFY: record rewards for plotting purposes writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) writer.add_scalar("losses/value_loss", v_loss.item(), global_step) writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) @@ -446,7 +448,9 @@ def clip_action(action: torch.Tensor): writer.add_scalar("losses/explained_variance", explained_var, global_step) print("SPS:", int(global_step / (time.time() - start_time))) writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) - + writer.add_scalar("charts/update_time", update_time, global_step) + writer.add_scalar("charts/rollout_time", rollout_time, global_step) + writer.add_scalar("charts/rollout_fps", args.num_envs * args.num_steps / rollout_time, global_step) if not args.evaluate: if args.save_model: model_path = f"runs/{run_name}/{args.exp_name}_final.cleanrl_model" diff --git a/examples/baselines/ppo/ppo_rgb.py b/examples/baselines/ppo/ppo_rgb.py index 2e8b383eb..f18836dfa 100644 --- a/examples/baselines/ppo/ppo_rgb.py +++ b/examples/baselines/ppo/ppo_rgb.py @@ -278,7 +278,7 @@ def get_action_and_value(self, x, action=None): device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") # env setup - env_kwargs = dict(obs_mode="rgbd", control_mode="pd_ee_delta_pose", render_mode="rgb_array") + env_kwargs = dict(obs_mode="rgbd", control_mode="pd_joint_delta_pos", render_mode="rgb_array") envs = gym.make(args.env_id, num_envs=args.num_envs, **env_kwargs) eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, **env_kwargs) @@ -366,7 +366,7 @@ def get_action_and_value(self, x, action=None): frac = 1.0 - (iteration - 1.0) / args.num_iterations lrnow = frac * args.learning_rate optimizer.param_groups[0]["lr"] = lrnow - + rollout_time = time.time() for step in range(0, args.num_steps): global_step += args.num_envs obs[step] = next_obs @@ -394,7 +394,7 @@ def get_action_and_value(self, x, action=None): for k in info["final_observation"]: info["final_observation"][k] = info["final_observation"][k][done_mask] final_values[step, torch.arange(args.num_envs, device=device)[done_mask]] = agent.get_value(info["final_observation"]).view(-1) - + rollout_time = time.time() - rollout_time # bootstrap value according to termination and truncation with torch.no_grad(): next_value = agent.get_value(next_obs).reshape(1, -1) @@ -450,6 +450,7 @@ def get_action_and_value(self, x, action=None): agent.train() b_inds = np.arange(args.batch_size) clipfracs = [] + update_time = time.time() for epoch in range(args.update_epochs): np.random.shuffle(b_inds) for start in range(0, args.batch_size, args.minibatch_size): @@ -503,12 +504,11 @@ def get_action_and_value(self, x, action=None): if args.target_kl is not None and approx_kl > args.target_kl: break - + update_time = time.time() - update_time y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() var_y = np.var(y_true) explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y - # TRY NOT TO MODIFY: record rewards for plotting purposes writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) writer.add_scalar("losses/value_loss", v_loss.item(), global_step) writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) @@ -519,6 +519,9 @@ def get_action_and_value(self, x, action=None): writer.add_scalar("losses/explained_variance", explained_var, global_step) print("SPS:", int(global_step / (time.time() - start_time))) writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + writer.add_scalar("charts/update_time", update_time, global_step) + writer.add_scalar("charts/rollout_time", rollout_time, global_step) + writer.add_scalar("charts/rollout_fps", args.num_envs * args.num_steps / rollout_time, global_step) if args.save_model: model_path = f"runs/{run_name}/{args.exp_name}_final.cleanrl_model" torch.save(agent.state_dict(), model_path) diff --git a/mani_skill/agents/controllers/pd_ee_pose.py b/mani_skill/agents/controllers/pd_ee_pose.py index 0ec885322..12cbe9fa7 100644 --- a/mani_skill/agents/controllers/pd_ee_pose.py +++ b/mani_skill/agents/controllers/pd_ee_pose.py @@ -20,8 +20,9 @@ from .pd_joint_pos import PDJointPosController -# NOTE(jigu): not necessary to inherit, just for convenience class PDEEPosController(PDJointPosController): + """The PD EE Position controller. NOTE that on the GPU it is assumed the controlled robot is not a merged articulation and is the same across every sub-scene""" + config: "PDEEPosControllerConfig" _target_pose = None @@ -55,9 +56,11 @@ def _initialize_joints(self): cur_link = cur_link.joint.parent_link active_ancestor_joints = active_ancestor_joints[::-1] self.active_ancestor_joints = active_ancestor_joints + # initially self.active_joint_indices references active joints that are controlled. + # we also make the assumption that the active index is the same across all parallel managed joints self.active_ancestor_joint_idxs = [ - x.active_index for x in self.active_ancestor_joints + (x.active_index[0]).cpu().item() for x in self.active_ancestor_joints ] controlled_joints_idx_in_qmask = [ self.active_ancestor_joint_idxs.index(idx) diff --git a/mani_skill/envs/tasks/push_cube.py b/mani_skill/envs/tasks/push_cube.py index 159ad174d..27f7453f1 100644 --- a/mani_skill/envs/tasks/push_cube.py +++ b/mani_skill/envs/tasks/push_cube.py @@ -190,12 +190,12 @@ def _get_obs_extra(self, info: Dict): # grippers of the robot obs = dict( tcp_pose=self.agent.tcp.pose.raw_pose, - goal_pos=self.goal_region.pose.p, ) if self._obs_mode in ["state", "state_dict"]: # if the observation mode is state/state_dict, we provide ground truth information about where the cube is. # for visual observation modes one should rely on the sensed visual data to determine where the cube is obs.update( + goal_pos=self.goal_region.pose.p, obj_pose=self.obj.pose.raw_pose, ) return obs From 2675240dd47a8beb9a684fa3683e55709512b65c Mon Sep 17 00:00:00 2001 From: Stone Tao Date: Wed, 24 Apr 2024 15:15:38 -0700 Subject: [PATCH 13/17] Update README.md --- examples/baselines/ppo/README.md | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/baselines/ppo/README.md b/examples/baselines/ppo/README.md index a6f4abf1a..9eafdf2ef 100644 --- a/examples/baselines/ppo/README.md +++ b/examples/baselines/ppo/README.md @@ -88,12 +88,9 @@ The visual PPO baseline is not guaranteed to work for tasks not tested below as python ppo_rgb.py --env_id="PushCube-v1" \ --num_envs=512 --update_epochs=8 --num_minibatches=16 \ --total_timesteps=1_000_000 --eval_freq=10 --num-steps=20 -python ppo_rgb.py --env_id="PickCube-v1" \ - --num_envs=256 --update_epochs=8 --num_minibatches=16 \ - --total_timesteps=10_000_000 python ppo_rgb.py --env_id="OpenCabinetDrawer-v1" \ - --num_envs=256 --update_epochs=8 --num_minibatches=32 \ - --total_timesteps=10_000_000 --num-steps=100 --num-eval-steps=100 + --num_envs=256 --update_epochs=8 --num_minibatches=16 \ + --total_timesteps=100_000_000 --num-steps=100 --num-eval-steps=100 ``` ## Some Notes From 0a5174480cea58c8bd0083a265ec6d3d3a64905e Mon Sep 17 00:00:00 2001 From: Stone Tao Date: Wed, 24 Apr 2024 16:55:45 -0700 Subject: [PATCH 14/17] eval visual rl policy --- examples/baselines/ppo/README.md | 13 +++- examples/baselines/ppo/ppo_rgb.py | 76 ++++++++++++------- mani_skill/utils/common.py | 13 +--- .../utils/structs/articulation_joint.py | 4 +- mani_skill/utils/wrappers/record.py | 5 +- 5 files changed, 69 insertions(+), 42 deletions(-) diff --git a/examples/baselines/ppo/README.md b/examples/baselines/ppo/README.md index 9eafdf2ef..837fe3e32 100644 --- a/examples/baselines/ppo/README.md +++ b/examples/baselines/ppo/README.md @@ -19,7 +19,8 @@ python ppo.py --env_id="PushCube-v1" \ To evaluate, you can run ```bash python ppo.py --env_id="PickCube-v1" \ - --evaluate --num_eval_envs=1 --checkpoint=runs/PickCube-v1__ppo__1__1710225023/ppo_101.cleanrl_model + --evaluate --checkpoint=path/to/model.cleanrl_model \ + --num_eval_envs=1 ``` Note that with `--evaluate`, trajectories are saved from a GPU simulation. In order to support replaying these trajectories correctly with the `maniskill.trajectory.replay_trajectory` tool, the number of evaluation environments must be fixed to `1`. This is necessary in order to ensure reproducibility for tasks that have randomizations on geometry (e.g. PickSingleYCB). @@ -93,6 +94,16 @@ python ppo_rgb.py --env_id="OpenCabinetDrawer-v1" \ --total_timesteps=100_000_000 --num-steps=100 --num-eval-steps=100 ``` +To evaluate a trained policy you can run + +```bash +python ppo_rgb.py --env_id="OpenCabinetDrawer-v1" \ + --evaluate --checkpoint=path/to/model.cleanrl_model \ + --num_eval_envs=1 --num-eval-steps=1000 +``` + +and it will save videos to the `path/to/test_videos`. + ## Some Notes - The code currently does not have the best way to evaluate the agents in that during GPU simulation, all assets are frozen per parallel environment (changing them slows training down). Thus when doing evaluation, even though we evaluate on multiple (8 is default) environments at once, they will always feature the same set of geometry. This only affects tasks where there is geometry variation (e.g. PickClutterYCB, OpenCabinetDrawer). You can make it more accurate by increasing the number of evaluation environments. Our team is discussing still what is the best way to evaluate trained agents properly without hindering performance. \ No newline at end of file diff --git a/examples/baselines/ppo/ppo_rgb.py b/examples/baselines/ppo/ppo_rgb.py index f18836dfa..d38492a5c 100644 --- a/examples/baselines/ppo/ppo_rgb.py +++ b/examples/baselines/ppo/ppo_rgb.py @@ -3,6 +3,7 @@ import random import time from dataclasses import dataclass +from typing import Optional import gymnasium as gym import numpy as np @@ -43,6 +44,10 @@ class Args: """whether to upload the saved model to huggingface""" hf_entity: str = "" """the user or org name of the model repository from the Hugging Face Hub""" + evaluate: bool = False + """if toggled, only runs evaluation with the given model checkpoint and saves the evaluation trajectories""" + checkpoint: str = None + """path to a pretrained checkpoint file to start evaluation/training from""" # Algorithm specific arguments env_id: str = "PickCube-v1" @@ -85,6 +90,8 @@ class Args: """the target KL divergence threshold""" eval_freq: int = 25 """evaluation frequency in terms of iterations""" + save_train_video_freq: Optional[int] = None + """frequency to save training videos in terms of iterations""" finite_horizon_gae: bool = True # to be filled in runtime @@ -251,23 +258,28 @@ def get_action_and_value(self, x, action=None): args.num_iterations = args.total_timesteps // args.batch_size run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" - if args.track: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=vars(args), - name=run_name, - monitor_gym=True, - save_code=True, + writer = None + if not args.evaluate: + print("Running training") + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), ) - writer = SummaryWriter(f"runs/{run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) + else: + print("Running evaluation") # TRY NOT TO MODIFY: seeding random.seed(args.seed) @@ -278,8 +290,8 @@ def get_action_and_value(self, x, action=None): device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") # env setup - env_kwargs = dict(obs_mode="rgbd", control_mode="pd_joint_delta_pos", render_mode="rgb_array") - envs = gym.make(args.env_id, num_envs=args.num_envs, **env_kwargs) + env_kwargs = dict(obs_mode="rgbd", control_mode="pd_joint_delta_pos", render_mode="rgb_array", sim_backend="gpu") + envs = gym.make(args.env_id, num_envs=args.num_envs if not args.evaluate else 1, **env_kwargs) eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, **env_kwargs) # rgbd obs mode returns a dict of data, we flatten it so there is just a rgbd key and state key @@ -289,7 +301,14 @@ def get_action_and_value(self, x, action=None): envs = FlattenActionSpaceWrapper(envs) eval_envs = FlattenActionSpaceWrapper(eval_envs) if args.capture_video: - eval_envs = RecordEpisode(eval_envs, output_dir=f"runs/{run_name}/videos", save_trajectory=False, max_steps_per_video=args.num_eval_steps, video_fps=30) + eval_output_dir = f"runs/{run_name}/videos" + if args.evaluate: + eval_output_dir = f"{os.path.dirname(args.checkpoint)}/test_videos" + print(f"Saving eval videos to {eval_output_dir}") + if args.save_train_video_freq is not None: + save_video_trigger = lambda x : (x // args.num_steps) % args.save_train_video_freq == 0 + envs = RecordEpisode(envs, output_dir=f"runs/{run_name}/train_videos", save_trajectory=False, save_video_trigger=save_video_trigger, max_steps_per_video=args.num_steps, video_fps=30) + eval_envs = RecordEpisode(eval_envs, output_dir=eval_output_dir, save_trajectory=args.evaluate, max_steps_per_video=args.num_eval_steps, video_fps=30) envs = ManiSkillVectorEnv(envs, args.num_envs, ignore_terminations=False, **env_kwargs) eval_envs = ManiSkillVectorEnv(eval_envs, args.num_eval_envs, ignore_terminations=False, **env_kwargs) assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" @@ -318,6 +337,9 @@ def get_action_and_value(self, x, action=None): agent = Agent(envs, sample_obs=next_obs).to(device) optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) + if args.checkpoint: + agent.load_state_dict(torch.load(args.checkpoint)) + for iteration in range(1, args.num_iterations + 1): print(f"Epoch: {iteration}, global_step={global_step}") final_values = torch.zeros((args.num_steps, args.num_envs), device=device) @@ -346,17 +368,19 @@ def get_action_and_value(self, x, action=None): print(f"Evaluated {args.num_eval_steps * args.num_envs} steps resulting in {len(eps_lens)} episodes") if len(successes) > 0: successes = np.concatenate(successes) - writer.add_scalar("charts/eval_success_rate", successes.mean(), global_step) + if writer is not None: writer.add_scalar("charts/eval_success_rate", successes.mean(), global_step) print(f"eval_success_rate={successes.mean()}") if len(failures) > 0: failures = np.concatenate(failures) - writer.add_scalar("charts/eval_fail_rate", failures.mean(), global_step) + if writer is not None: writer.add_scalar("charts/eval_fail_rate", failures.mean(), global_step) print(f"eval_fail_rate={failures.mean()}") print(f"eval_episodic_return={returns.mean()}") - writer.add_scalar("charts/eval_episodic_return", returns.mean(), global_step) - writer.add_scalar("charts/eval_episodic_length", eps_lens.mean(), global_step) - + if writer is not None: + writer.add_scalar("charts/eval_episodic_return", returns.mean(), global_step) + writer.add_scalar("charts/eval_episodic_length", eps_lens.mean(), global_step) + if args.evaluate: + break if args.save_model and iteration % args.eval_freq == 1: model_path = f"runs/{run_name}/{args.exp_name}_{iteration}.cleanrl_model" torch.save(agent.state_dict(), model_path) @@ -522,10 +546,10 @@ def get_action_and_value(self, x, action=None): writer.add_scalar("charts/update_time", update_time, global_step) writer.add_scalar("charts/rollout_time", rollout_time, global_step) writer.add_scalar("charts/rollout_fps", args.num_envs * args.num_steps / rollout_time, global_step) - if args.save_model: + if args.save_model and not args.evaluate: model_path = f"runs/{run_name}/{args.exp_name}_final.cleanrl_model" torch.save(agent.state_dict(), model_path) print(f"model saved to {model_path}") envs.close() - writer.close() + if writer is not None: writer.close() diff --git a/mani_skill/utils/common.py b/mani_skill/utils/common.py index 6237409d0..ee1488b43 100644 --- a/mani_skill/utils/common.py +++ b/mani_skill/utils/common.py @@ -217,8 +217,6 @@ def flatten_state_dict( state = flatten_state_dict(value, use_torch=use_torch) if state.size == 0: state = None - if use_torch: - state = to_tensor(state) elif isinstance(value, (tuple, list)): state = None if len(value) == 0 else value if use_torch: @@ -241,15 +239,10 @@ def flatten_state_dict( if use_torch: state = to_tensor(state) + elif isinstance(value, torch.Tensor): + state = value else: - is_torch_tensor = False - if isinstance(value, torch.Tensor): - state = value - if len(state.shape) == 1: - state = state[:, None] - is_torch_tensor = True - if not is_torch_tensor: - raise TypeError("Unsupported type: {}".format(type(value))) + raise TypeError("Unsupported type: {}".format(type(value))) if state is not None: states.append(state) diff --git a/mani_skill/utils/structs/articulation_joint.py b/mani_skill/utils/structs/articulation_joint.py index 67dbaae18..a2206ce95 100644 --- a/mani_skill/utils/structs/articulation_joint.py +++ b/mani_skill/utils/structs/articulation_joint.py @@ -97,9 +97,7 @@ def qpos(self): self._data_index, self.active_index ] else: - return torch.tensor( - [[self._physx_articulations[0].qpos[self.active_index]]] - ) + return torch.tensor([self._physx_articulations[0].qpos[self.active_index]]) @property def qvel(self): diff --git a/mani_skill/utils/wrappers/record.py b/mani_skill/utils/wrappers/record.py index fc0455894..11af5383d 100644 --- a/mani_skill/utils/wrappers/record.py +++ b/mani_skill/utils/wrappers/record.py @@ -372,6 +372,8 @@ def recursive_replace(x, y): for k in x.keys(): recursive_replace(x[k], y[k]) + # TODO (stao): how do we store states from GPU sim of tasks with objects not in every sub-scene? + # Maybe we shouldn't? recursive_replace(self._trajectory_buffer.state, first_step.state) recursive_replace( self._trajectory_buffer.observation, first_step.observation @@ -484,7 +486,7 @@ def step(self, action): def flush_trajectory( self, verbose=False, - ignore_empty_transition=False, + ignore_empty_transition=True, env_idxs_to_flush=None, ): flush_count = 0 @@ -598,7 +600,6 @@ def recursive_add_to_h5py(group: h5py.Group, data: dict, key): fail=self._trajectory_buffer.success[end_ptr - 1, env_idx] ) recursive_add_to_h5py(group, self._trajectory_buffer.state, "env_states") - if self.record_reward: group.create_dataset( "rewards", From dd83a62996fcb0a52b8248700768c4711f7d46e2 Mon Sep 17 00:00:00 2001 From: Stone Tao Date: Wed, 24 Apr 2024 17:09:12 -0700 Subject: [PATCH 15/17] finalize code --- examples/baselines/ppo/README.md | 24 ++++++++++++++++++------ examples/baselines/ppo/ppo.py | 4 ++-- examples/baselines/ppo/ppo_rgb.py | 4 ++-- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/examples/baselines/ppo/README.md b/examples/baselines/ppo/README.md index 837fe3e32..9fd62a204 100644 --- a/examples/baselines/ppo/README.md +++ b/examples/baselines/ppo/README.md @@ -13,17 +13,17 @@ The PPO baseline is not guaranteed to work for tasks not tested below as some ta ```bash python ppo.py --env_id="PushCube-v1" \ --num_envs=2048 --update_epochs=8 --num_minibatches=32 \ - --total_timesteps=5_000_000 --eval_freq=10 --num-steps=20 + --total_timesteps=2_000_000 --eval_freq=10 --num-steps=20 ``` To evaluate, you can run ```bash -python ppo.py --env_id="PickCube-v1" \ - --evaluate --checkpoint=path/to/model.cleanrl_model \ - --num_eval_envs=1 +python ppo.py --env_id="PushCube-v1" \ + --evaluate --checkpoint=path/to/model.pt \ + --num_eval_envs=1 --num-eval-steps=1000 ``` -Note that with `--evaluate`, trajectories are saved from a GPU simulation. In order to support replaying these trajectories correctly with the `maniskill.trajectory.replay_trajectory` tool, the number of evaluation environments must be fixed to `1`. This is necessary in order to ensure reproducibility for tasks that have randomizations on geometry (e.g. PickSingleYCB). +Note that with `--evaluate`, trajectories are saved from a GPU simulation. In order to support replaying these trajectories correctly with the `maniskill.trajectory.replay_trajectory` tool for some task, the number of evaluation environments must be fixed to `1`. This is necessary in order to ensure reproducibility for tasks that have randomizations on geometry (e.g. PickSingleYCB). Other tasks without geometrical randomization like PushCube are fine and you can increase the number of evaluation environments. Below is a full list of various commands you can run to train a policy to solve various tasks with PPO that are lightly tuned already. The fastest one is the PushCube-v1 task which can take less than a minute to train on the GPU. @@ -98,12 +98,24 @@ To evaluate a trained policy you can run ```bash python ppo_rgb.py --env_id="OpenCabinetDrawer-v1" \ - --evaluate --checkpoint=path/to/model.cleanrl_model \ + --evaluate --checkpoint=path/to/model.pt \ --num_eval_envs=1 --num-eval-steps=1000 ``` and it will save videos to the `path/to/test_videos`. +## Replaying Evaluation Trajectories + +It might be useful to get some nicer looking videos. A simple way to do that is to first use the evaluation scripts provided above. It will then save a .h5 and .json file with a name equal to the date and time that you can then replay with different settings as so + +```bash +python -m mani_skill.trajectory.replay_trajectory \ + --traj-path=path/to/trajectory.h5 --use-env-states --shader="rt-fast" \ + --save-video --allow-failure +``` + +This will use environment states to replay trajectories, turn on the ray-tracer (There is also "rt" which is higher quality but slower), and save all videos including failed trajectories. + ## Some Notes - The code currently does not have the best way to evaluate the agents in that during GPU simulation, all assets are frozen per parallel environment (changing them slows training down). Thus when doing evaluation, even though we evaluate on multiple (8 is default) environments at once, they will always feature the same set of geometry. This only affects tasks where there is geometry variation (e.g. PickClutterYCB, OpenCabinetDrawer). You can make it more accurate by increasing the number of evaluation environments. Our team is discussing still what is the best way to evaluate trained agents properly without hindering performance. \ No newline at end of file diff --git a/examples/baselines/ppo/ppo.py b/examples/baselines/ppo/ppo.py index 529e0b20c..396cb1408 100644 --- a/examples/baselines/ppo/ppo.py +++ b/examples/baselines/ppo/ppo.py @@ -200,7 +200,7 @@ def get_action_and_value(self, x, action=None): if args.capture_video: eval_output_dir = f"runs/{run_name}/videos" if args.evaluate: - eval_output_dir = f"videos" + eval_output_dir = f"{os.path.dirname(args.checkpoint)}/test_videos" print(f"Saving eval videos to {eval_output_dir}") if args.save_train_video_freq is not None: save_video_trigger = lambda x : (x // args.num_steps) % args.save_train_video_freq == 0 @@ -283,7 +283,7 @@ def clip_action(action: torch.Tensor): if args.evaluate: break if args.save_model and iteration % args.eval_freq == 1: - model_path = f"runs/{run_name}/{args.exp_name}_{iteration}.cleanrl_model" + model_path = f"runs/{run_name}/{args.exp_name}_{iteration}.pt" torch.save(agent.state_dict(), model_path) print(f"model saved to {model_path}") # Annealing the rate if instructed to do so. diff --git a/examples/baselines/ppo/ppo_rgb.py b/examples/baselines/ppo/ppo_rgb.py index d38492a5c..f3770dcb4 100644 --- a/examples/baselines/ppo/ppo_rgb.py +++ b/examples/baselines/ppo/ppo_rgb.py @@ -382,7 +382,7 @@ def get_action_and_value(self, x, action=None): if args.evaluate: break if args.save_model and iteration % args.eval_freq == 1: - model_path = f"runs/{run_name}/{args.exp_name}_{iteration}.cleanrl_model" + model_path = f"runs/{run_name}/{args.exp_name}_{iteration}.pt" torch.save(agent.state_dict(), model_path) print(f"model saved to {model_path}") # Annealing the rate if instructed to do so. @@ -547,7 +547,7 @@ def get_action_and_value(self, x, action=None): writer.add_scalar("charts/rollout_time", rollout_time, global_step) writer.add_scalar("charts/rollout_fps", args.num_envs * args.num_steps / rollout_time, global_step) if args.save_model and not args.evaluate: - model_path = f"runs/{run_name}/{args.exp_name}_final.cleanrl_model" + model_path = f"runs/{run_name}/{args.exp_name}_final.pt" torch.save(agent.state_dict(), model_path) print(f"model saved to {model_path}") From e3ccfb77006cfbc33345e9b79900619edf5ae1d9 Mon Sep 17 00:00:00 2001 From: Stone Tao Date: Wed, 24 Apr 2024 17:23:54 -0700 Subject: [PATCH 16/17] fixes --- examples/baselines/ppo/ppo_rgb.py | 2 +- mani_skill/envs/tasks/open_cabinet_drawer.py | 2 +- mani_skill/utils/common.py | 7 +++++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/baselines/ppo/ppo_rgb.py b/examples/baselines/ppo/ppo_rgb.py index f3770dcb4..50042842d 100644 --- a/examples/baselines/ppo/ppo_rgb.py +++ b/examples/baselines/ppo/ppo_rgb.py @@ -290,7 +290,7 @@ def get_action_and_value(self, x, action=None): device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") # env setup - env_kwargs = dict(obs_mode="rgbd", control_mode="pd_joint_delta_pos", render_mode="rgb_array", sim_backend="gpu") + env_kwargs = dict(obs_mode="rgbd", control_mode="pd_joint_delta_pos", render_mode="sensors", sim_backend="gpu") envs = gym.make(args.env_id, num_envs=args.num_envs if not args.evaluate else 1, **env_kwargs) eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, **env_kwargs) diff --git a/mani_skill/envs/tasks/open_cabinet_drawer.py b/mani_skill/envs/tasks/open_cabinet_drawer.py index 4dc8d614a..3752ab30e 100644 --- a/mani_skill/envs/tasks/open_cabinet_drawer.py +++ b/mani_skill/envs/tasks/open_cabinet_drawer.py @@ -157,7 +157,7 @@ def _load_cabinets(self, joint_types: List[str]): body_type="kinematic", add_collision=False, ) - self._hidden_objects.append(self.handle_link_goal) + # self._hidden_objects.append(self.handle_link_goal) def _after_reconfigure(self, options): # To spawn cabinets in the right place, we need to change their z position such that diff --git a/mani_skill/utils/common.py b/mani_skill/utils/common.py index ee1488b43..19b5677f0 100644 --- a/mani_skill/utils/common.py +++ b/mani_skill/utils/common.py @@ -164,9 +164,9 @@ def to_tensor(array: Union[torch.Tensor, np.array, Sequence], device: Device = N if ret.dtype == torch.float64: ret = ret.float() elif np.iterable(array): - ret = torch.tensor(array) + ret = torch.Tensor(array) else: - ret = torch.tensor(array) + ret = torch.Tensor(array) if device is None: return ret else: @@ -217,6 +217,7 @@ def flatten_state_dict( state = flatten_state_dict(value, use_torch=use_torch) if state.size == 0: state = None + state = to_tensor(state) elif isinstance(value, (tuple, list)): state = None if len(value) == 0 else value if use_torch: @@ -241,6 +242,8 @@ def flatten_state_dict( elif isinstance(value, torch.Tensor): state = value + if len(state.shape) == 1: + state = state[:, None] else: raise TypeError("Unsupported type: {}".format(type(value))) if state is not None: From f3cab93db9b04e5ecf88ba26bdf4e9fc7fe2487f Mon Sep 17 00:00:00 2001 From: Stone Tao Date: Wed, 24 Apr 2024 17:26:31 -0700 Subject: [PATCH 17/17] work --- examples/baselines/ppo/README.md | 6 +++--- examples/baselines/ppo/ppo_rgb.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/baselines/ppo/README.md b/examples/baselines/ppo/README.md index 9fd62a204..1b2313425 100644 --- a/examples/baselines/ppo/README.md +++ b/examples/baselines/ppo/README.md @@ -98,8 +98,8 @@ To evaluate a trained policy you can run ```bash python ppo_rgb.py --env_id="OpenCabinetDrawer-v1" \ - --evaluate --checkpoint=path/to/model.pt \ - --num_eval_envs=1 --num-eval-steps=1000 + --evaluate --checkpoint=path/to/model.pt \ + --num_eval_envs=1 --num-eval-steps=1000 ``` and it will save videos to the `path/to/test_videos`. @@ -111,7 +111,7 @@ It might be useful to get some nicer looking videos. A simple way to do that is ```bash python -m mani_skill.trajectory.replay_trajectory \ --traj-path=path/to/trajectory.h5 --use-env-states --shader="rt-fast" \ - --save-video --allow-failure + --save-video --allow-failure -o "none" ``` This will use environment states to replay trajectories, turn on the ray-tracer (There is also "rt" which is higher quality but slower), and save all videos including failed trajectories. diff --git a/examples/baselines/ppo/ppo_rgb.py b/examples/baselines/ppo/ppo_rgb.py index 50042842d..f3770dcb4 100644 --- a/examples/baselines/ppo/ppo_rgb.py +++ b/examples/baselines/ppo/ppo_rgb.py @@ -290,7 +290,7 @@ def get_action_and_value(self, x, action=None): device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") # env setup - env_kwargs = dict(obs_mode="rgbd", control_mode="pd_joint_delta_pos", render_mode="sensors", sim_backend="gpu") + env_kwargs = dict(obs_mode="rgbd", control_mode="pd_joint_delta_pos", render_mode="rgb_array", sim_backend="gpu") envs = gym.make(args.env_id, num_envs=args.num_envs if not args.evaluate else 1, **env_kwargs) eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, **env_kwargs)