
TD3 / DDPG action bound fix #211


Merged
22 commits merged on Jun 30, 2022

Changes from all commits (22 commits)
44b3fe1
Prototype the actor's output action scaling to the action space of th…
dosssman Jun 20, 2022
e9fbacf
TD3 and DDPG's action scale and bias move to GPU if needs be
dosssman Jun 20, 2022
ad4dc49
Fixed formatting
dosssman Jun 21, 2022
d6b96e9
pre-commit fixed formatting
dosssman Jun 21, 2022
ba93aaf
action_scale and action_bias in TD3 / DDPG and SAC finally using regi…
dosssman Jun 22, 2022
0e20bac
Fixed the = self.register_buffer code artifact
dosssman Jun 22, 2022
9985d63
TD3 adjusted the exploration noise for the policy during the rollout …
dosssman Jun 25, 2022
ecb2130
Removed obsolete next obs actions clamping
dosssman Jun 25, 2022
c494e72
td3 format fixed by pre-commit
dosssman Jun 25, 2022
6eef004
cosmetic change: make `handle_timeout_termination` explicit
vwxyzjn Jun 29, 2022
db6ff29
Quick fix
vwxyzjn Jun 29, 2022
7a1ab33
update docs
vwxyzjn Jun 29, 2022
4a9425a
Update benchmark
vwxyzjn Jun 29, 2022
302b77a
update docs
vwxyzjn Jun 29, 2022
b261ac1
Update docs
vwxyzjn Jun 29, 2022
48e87d8
DDPG and TD3: got rid of max_action, exploration noise is sampled fro…
dosssman Jun 29, 2022
a6b40b7
Update docs
vwxyzjn Jun 29, 2022
10b606e
Merge branch 'td3_ddpg_action_bound_fix' of https://github.com/dosssm…
vwxyzjn Jun 29, 2022
1461980
Updated TD3 and DDPG regarding the action_mean and action_scale usage
dosssman Jun 29, 2022
18487fa
Merge branch 'td3_ddpg_action_bound_fix' of https://github.com/dosssm…
dosssman Jun 29, 2022
c2ffe83
Reduced needless device copy when sampling action for rollouts in DDP…
dosssman Jun 29, 2022
3ed96b9
Quick fix
vwxyzjn Jun 30, 2022
2 changes: 1 addition & 1 deletion benchmark/ddpg.sh
@@ -1,7 +1,7 @@
poetry install -E "mujoco pybullet"
python -c "import mujoco_py"
OMP_NUM_THREADS=1 xvfb-run -a python -m cleanrl_utils.benchmark \
--env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 \
--env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 InvertedPendulum-v2 Humanoid-v2 Pusher-v2 \
--command "poetry run python cleanrl/ddpg_continuous_action.py --track --capture-video" \
--num-seeds 3 \
--workers 3
2 changes: 1 addition & 1 deletion benchmark/td3.sh
@@ -1,7 +1,7 @@
poetry install -E "mujoco pybullet"
python -c "import mujoco_py"
OMP_NUM_THREADS=1 xvfb-run -a python -m cleanrl_utils.benchmark \
--env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 \
--env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 InvertedPendulum-v2 Humanoid-v2 Pusher-v2 \
--command "poetry run python cleanrl/td3_continuous_action.py --track --capture-video" \
--num-seeds 3 \
--workers 3
32 changes: 17 additions & 15 deletions cleanrl/ddpg_continuous_action.py
@@ -101,11 +101,15 @@ def __init__(self, env):
self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256)
self.fc2 = nn.Linear(256, 256)
self.fc_mu = nn.Linear(256, np.prod(env.single_action_space.shape))
# action rescaling
self.register_buffer("action_scale", torch.FloatTensor((env.action_space.high - env.action_space.low) / 2.0))
self.register_buffer("action_bias", torch.FloatTensor((env.action_space.high + env.action_space.low) / 2.0))

def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return torch.tanh(self.fc_mu(x))
x = torch.tanh(self.fc_mu(x))
return x * self.action_scale + self.action_bias


if __name__ == "__main__":
@@ -141,7 +145,6 @@ def forward(self, x):
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

max_action = float(envs.single_action_space.high[0])
actor = Actor(envs).to(device)
qf1 = QNetwork(envs).to(device)
qf1_target = QNetwork(envs).to(device)
@@ -152,7 +155,13 @@ def forward(self, x):
actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.learning_rate)

envs.single_observation_space.dtype = np.float32
rb = ReplayBuffer(args.buffer_size, envs.single_observation_space, envs.single_action_space, device=device)
rb = ReplayBuffer(
args.buffer_size,
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
@@ -162,15 +171,10 @@ def forward(self, x):
if global_step < args.learning_starts:
actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
else:
actions = actor(torch.Tensor(obs).to(device))
actions = np.array(
[
(
actions.tolist()[0]
+ np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape[0])
).clip(envs.single_action_space.low, envs.single_action_space.high)
]
)
with torch.no_grad():
actions = actor(torch.Tensor(obs).to(device))
actions += torch.normal(actor.action_bias, actor.action_scale * args.exploration_noise)
actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
@@ -197,9 +201,7 @@ def forward(self, x):
if global_step > args.learning_starts:
data = rb.sample(args.batch_size)
with torch.no_grad():
next_state_actions = (target_actor(data.next_observations)).clamp(
envs.single_action_space.low[0], envs.single_action_space.high[0]
)
next_state_actions = target_actor(data.next_observations)
qf1_next_target = qf1_target(data.next_observations, next_state_actions)
next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (qf1_next_target).view(-1)

9 changes: 2 additions & 7 deletions cleanrl/sac_continuous_action.py
@@ -115,8 +115,8 @@ def __init__(self, env):
self.fc_mean = nn.Linear(256, np.prod(env.single_action_space.shape))
self.fc_logstd = nn.Linear(256, np.prod(env.single_action_space.shape))
# action rescaling
self.action_scale = torch.FloatTensor((env.action_space.high - env.action_space.low) / 2.0)
self.action_bias = torch.FloatTensor((env.action_space.high + env.action_space.low) / 2.0)
self.register_buffer("action_scale", torch.FloatTensor((env.action_space.high - env.action_space.low) / 2.0))
self.register_buffer("action_bias", torch.FloatTensor((env.action_space.high + env.action_space.low) / 2.0))

def forward(self, x):
x = F.relu(self.fc1(x))
@@ -142,11 +142,6 @@ def get_action(self, x):
mean = torch.tanh(mean) * self.action_scale + self.action_bias
return action, log_prob, mean

def to(self, device):
self.action_scale = self.action_scale.to(device)
self.action_bias = self.action_bias.to(device)
return super().to(device)


if __name__ == "__main__":
args = parse_args()
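The `to()` override removed from the SAC actor above is exactly what `register_buffer` makes unnecessary: buffers are not trainable parameters, but `nn.Module.to(...)` moves them along with the model and `state_dict()` serializes them. A minimal sketch of that behaviour (the `ScaledActor` class and the example bounds are illustrative, not code from this PR):

```python
import numpy as np
import torch
import torch.nn as nn


class ScaledActor(nn.Module):
    """Minimal stand-in for the actors in this PR; layer sizes are arbitrary."""

    def __init__(self, low, high):
        super().__init__()
        self.fc = nn.Linear(4, len(low))
        # Buffers are not parameters, but nn.Module.to(...) moves them and
        # state_dict() saves them, so no custom `to()` override is needed.
        self.register_buffer("action_scale", torch.tensor((high - low) / 2.0, dtype=torch.float32))
        self.register_buffer("action_bias", torch.tensor((high + low) / 2.0, dtype=torch.float32))

    def forward(self, x):
        return torch.tanh(self.fc(x)) * self.action_scale + self.action_bias


if __name__ == "__main__":
    actor = ScaledActor(low=np.array([-2.0, -0.4]), high=np.array([2.0, 0.4]))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    actor = actor.to(device)  # action_scale / action_bias follow automatically
    print(actor.action_scale.device)
    print(sorted(actor.state_dict().keys()))  # includes the two buffers
```

A practical side effect is that a saved checkpoint now carries the action bounds the model was trained with.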
20 changes: 9 additions & 11 deletions cleanrl/td3_continuous_action.py
@@ -103,11 +103,15 @@ def __init__(self, env):
self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256)
self.fc2 = nn.Linear(256, 256)
self.fc_mu = nn.Linear(256, np.prod(env.single_action_space.shape))
# action rescaling
self.register_buffer("action_scale", torch.FloatTensor((env.action_space.high - env.action_space.low) / 2.0))
self.register_buffer("action_bias", torch.FloatTensor((env.action_space.high + env.action_space.low) / 2.0))

def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return torch.tanh(self.fc_mu(x))
x = torch.tanh(self.fc_mu(x))
return x * self.action_scale + self.action_bias


if __name__ == "__main__":
@@ -143,7 +147,6 @@ def forward(self, x):
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

max_action = float(envs.single_action_space.high[0])
actor = Actor(envs).to(device)
qf1 = QNetwork(envs).to(device)
qf2 = QNetwork(envs).to(device)
@@ -173,15 +176,10 @@ def forward(self, x):
if global_step < args.learning_starts:
actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
else:
actions = actor(torch.Tensor(obs).to(device))
actions = np.array(
[
(
actions.tolist()[0]
+ np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape[0])
).clip(envs.single_action_space.low, envs.single_action_space.high)
]
)
with torch.no_grad():
actions = actor(torch.Tensor(obs).to(device))
actions += torch.normal(actor.action_bias, actor.action_scale * args.exploration_noise)
actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
70 changes: 57 additions & 13 deletions docs/rl-algorithms/ddpg.md
@@ -145,18 +145,54 @@ Our [`ddpg_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master
```

1. [`ddpg_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action.py) uses `--batch-size=256 --tau=0.005`, while (Lillicrap et al., 2016, see Appendix 7 EXPERIMENT DETAILS)[^1] uses `--batch-size=64 --tau=0.001`
<!--
1. Vectorized architecture (:material-github: [common/cmd_util.py#L22](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/cmd_util.py#L22))
1. Orthogonal Initialization of Weights and Constant Initialization of biases (:material-github: [a2c/utils.py#L58)](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/a2c/utils.py#L58))
1. The Adam Optimizer's Epsilon Parameter (:material-github: [ppo2/model.py#L100](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/model.py#L100))
1. Adam Learning Rate Annealing (:material-github: [ppo2/ppo2.py#L133-L135](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/ppo2.py#L133-L135))
1. Generalized Advantage Estimation (:material-github: [ppo2/runner.py#L56-L65](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/runner.py#L56-L65))
1. Mini-batch Updates (:material-github: [ppo2/ppo2.py#L157-L166](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/ppo2.py#L157-L166))
1. Normalization of Advantages (:material-github: [ppo2/model.py#L139](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/model.py#L139))
1. Clipped surrogate objective (:material-github: [ppo2/model.py#L81-L86](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/model.py#L81-L86))
1. Value Function Loss Clipping (:material-github: [ppo2/model.py#L68-L75](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/model.py#L68-L75))
1. Overall Loss and Entropy Bonus (:material-github: [ppo2/model.py#L91](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/model.py#L91))
1. Global Gradient Clipping (:material-github: [ppo2/model.py#L102-L108](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/model.py#L102-L108)) -->

1. [`ddpg_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action.py) also adds support for handling continuous environments where the lower and upper bounds of the action space are not $[-1,1]$, or are asymmetric.
The case where the bounds are not $[-1,1]$ is handled in [`DDPG.py`](https://github.com/sfujim/TD3/blob/385b33ac7de4767bab17eb02ade4a268d3e4e24f/DDPG.py#L15) (Fujimoto et al., 2018)[^2] as follows:
```python
class Actor(nn.Module):

...

def forward(self, state):
a = F.relu(self.l1(state))
a = F.relu(self.l2(a))
return self.max_action * torch.tanh(self.l3(a)) # Scale from [-1,1] to [-action_high, action_high]
```
On the other hand, in [`CleanRL's ddpg_continuous_action.py`](https://github.com/dosssman/cleanrl/blob/10b606e7bd9bd1b06e455e8ef542df2b7699a20c/cleanrl/ddpg_continuous_action.py#L98), the mean and the scale of the action space are computed as `action_bias` and `action_scale` respectively.
Those scalars are in turn used to scale the output of a `tanh` activation function in the actor to the original action space range:
```python
class Actor(nn.Module):
def __init__(self, env):
...
# action rescaling
self.register_buffer("action_scale", torch.FloatTensor((env.action_space.high - env.action_space.low) / 2.0))
self.register_buffer("action_bias", torch.FloatTensor((env.action_space.high + env.action_space.low) / 2.0))

def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = torch.tanh(self.fc_mu(x))
return x * self.action_scale + self.action_bias # Scale from [-1,1] to [action_low, action_high]
```

Additionally, when drawing the exploration noise that is added to the actions produced by the actor, [`CleanRL's ddpg_continuous_action.py`](https://github.com/dosssman/cleanrl/blob/10b606e7bd9bd1b06e455e8ef542df2b7699a20c/cleanrl/ddpg_continuous_action.py#L175) centers the distribution the noise is sampled from at `action_bias`, and sets the scale of the distribution to `action_scale * exploration_noise`.
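A self-contained sketch of that noise step (the bounds and the `exploration_noise` value below are illustrative, not taken from the PR; the real script perturbs the actor's output for the current observation):

```python
import torch

# Illustrative Pusher-v2-like bounds: Box(-2.0, 2.0, (7,))
low, high = torch.full((7,), -2.0), torch.full((7,), 2.0)
action_scale = (high - low) / 2.0
action_bias = (high + low) / 2.0
exploration_noise = 0.1  # example value for --exploration-noise

deterministic_action = torch.tanh(torch.randn(7)) * action_scale + action_bias  # stand-in actor output
noise = torch.normal(action_bias, action_scale * exploration_noise)  # mean action_bias, std action_scale * exploration_noise
noisy_action = (deterministic_action + noise).clamp(-2.0, 2.0)  # clip back into the Box bounds
print(noisy_action)
```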

???+ info

Note that `Humanoid-v2`, `InvertedPendulum-v2`, `Pusher-v2` have action space bounds that are not the standard `[-1, 1]`. See below.

```
Ant-v2 Observation space: Box(-inf, inf, (111,), float64) Action space: Box(-1.0, 1.0, (8,), float32)
HalfCheetah-v2 Observation space: Box(-inf, inf, (17,), float64) Action space: Box(-1.0, 1.0, (6,), float32)
Hopper-v2 Observation space: Box(-inf, inf, (11,), float64) Action space: Box(-1.0, 1.0, (3,), float32)
Humanoid-v2 Observation space: Box(-inf, inf, (376,), float64) Action space: Box(-0.4, 0.4, (17,), float32)
InvertedDoublePendulum-v2 Observation space: Box(-inf, inf, (11,), float64) Action space: Box(-1.0, 1.0, (1,), float32)
InvertedPendulum-v2 Observation space: Box(-inf, inf, (4,), float64) Action space: Box(-3.0, 3.0, (1,), float32)
Pusher-v2 Observation space: Box(-inf, inf, (23,), float64) Action space: Box(-2.0, 2.0, (7,), float32)
Reacher-v2 Observation space: Box(-inf, inf, (11,), float64) Action space: Box(-1.0, 1.0, (2,), float32)
Swimmer-v2 Observation space: Box(-inf, inf, (8,), float64) Action space: Box(-1.0, 1.0, (2,), float32)
Walker2d-v2 Observation space: Box(-inf, inf, (17,), float64) Action space: Box(-1.0, 1.0, (6,), float32)
```
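Plugging those bounds into the `action_scale` / `action_bias` formulas above gives the values the actor registers (a quick sanity check in plain Python, not code from the PR):

```python
def scale_and_bias(low, high):
    # Mirrors the Actor's buffers: (high - low) / 2 and (high + low) / 2
    return (high - low) / 2.0, (high + low) / 2.0

print(scale_and_bias(-0.4, 0.4))  # Humanoid-v2         -> (0.4, 0.0)
print(scale_and_bias(-3.0, 3.0))  # InvertedPendulum-v2 -> (3.0, 0.0)
print(scale_and_bias(-2.0, 2.0))  # Pusher-v2           -> (2.0, 0.0)
```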


### Experiment results
@@ -172,7 +208,9 @@ Below are the average episodic returns for [`ddpg_continuous_action.py`](https:/
| HalfCheetah | 9382.32 ± 1395.52 | 8577.29 | 3305.60 |
| Walker2d | 1598.35 ± 862.66 | 3098.11 | 1843.85 |
| Hopper | 1313.43 ± 684.46 | 1860.02 | 2020.46 |
| Humanoid | 897.74 ± 281.87 | not available |
| Pusher | -34.45 ± 4.47 | not available |
| InvertedPendulum | 645.67 ± 270.31 | 1000.00 ± 0.00 |


???+ info
@@ -191,6 +229,12 @@ Learning curves:
<img src="../ddpg/Walker2d-v2.png">

<img src="../ddpg/Hopper-v2.png">

<img src="../ddpg/Humanoid-v2.png">

<img src="../ddpg/Pusher-v2.png">

<img src="../ddpg/InvertedPendulum-v2.png">
</div>


Binary file added docs/rl-algorithms/ddpg/Humanoid-v2.png
Binary file added docs/rl-algorithms/ddpg/InvertedPendulum-v2.png
Binary file added docs/rl-algorithms/ddpg/Pusher-v2.png
56 changes: 56 additions & 0 deletions docs/rl-algorithms/td3.md
@@ -64,6 +64,53 @@ Our [`td3_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/

1. [`td3_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action.py) uses two separate objects `qf1` and `qf2` to represent the two Q functions in the Clipped Double Q-learning architecture, whereas [`TD3.py`](https://github.com/sfujim/TD3/blob/master/TD3.py) (Fujimoto et al., 2018)[^2] uses a single `Critic` class that contains both Q networks. That said, these two implementations are virtually the same.

2. [`td3_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action.py) also adds support for handling continuous environments where the lower and upper bounds of the action space are not $[-1,1]$, or are asymmetric.
The case where the bounds are not $[-1,1]$ is handled in [`TD3.py`](https://github.com/sfujim/TD3/blob/385b33ac7de4767bab17eb02ade4a268d3e4e24f/TD3.py#L28) (Fujimoto et al., 2018)[^2] as follows:
```python
class Actor(nn.Module):

...

def forward(self, state):
a = F.relu(self.l1(state))
a = F.relu(self.l2(a))
return self.max_action * torch.tanh(self.l3(a)) # Scale from [-1,1] to [-action_high, action_high]
```
On the other hand, in [`CleanRL's td3_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action.py), the mean and the scale of the action space are computed as `action_bias` and `action_scale` respectively.
Those scalars are in turn used to scale the output of a `tanh` activation function in the actor to the original action space range:
```python
class Actor(nn.Module):
def __init__(self, env):
...
# action rescaling
self.register_buffer("action_scale", torch.FloatTensor((env.action_space.high - env.action_space.low) / 2.0))
self.register_buffer("action_bias", torch.FloatTensor((env.action_space.high + env.action_space.low) / 2.0))

def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = torch.tanh(self.fc_mu(x))
return x * self.action_scale + self.action_bias # Scale from [-1,1] to [action_low, action_high]
```
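For an asymmetric action space the same two buffers also recenter the output: `tanh` lands in $(-1, 1)$ and the affine map sends it into the space's own range. A small illustration with made-up bounds (not taken from any of the benchmarked environments):

```python
import torch

low, high = torch.tensor([-1.0]), torch.tensor([3.0])  # hypothetical asymmetric bounds
action_scale = (high - low) / 2.0  # -> 2.0
action_bias = (high + low) / 2.0   # -> 1.0

raw = torch.linspace(-5.0, 5.0, steps=5)               # pre-tanh actor outputs
scaled = torch.tanh(raw) * action_scale + action_bias  # every value stays inside (-1.0, 3.0)
print(scaled)
```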

Additionally, when drawing the exploration noise that is added to the actions produced by the actor, [`CleanRL's td3_continuous_action.py`](https://github.com/dosssman/cleanrl/blob/10b606e7bd9bd1b06e455e8ef542df2b7699a20c/cleanrl/td3_continuous_action.py#L180) centers the distribution the noise is sampled from at `action_bias`, and sets the scale of the distribution to `action_scale * exploration_noise`.

???+ info

Note that `Humanoid-v2`, `InvertedPendulum-v2`, and `Pusher-v2` have action space bounds that are not the standard `[-1, 1]`. See below and :material-github: [PR #196](https://github.com/vwxyzjn/cleanrl/issues/196).

```
Ant-v2 Observation space: Box(-inf, inf, (111,), float64) Action space: Box(-1.0, 1.0, (8,), float32)
HalfCheetah-v2 Observation space: Box(-inf, inf, (17,), float64) Action space: Box(-1.0, 1.0, (6,), float32)
Hopper-v2 Observation space: Box(-inf, inf, (11,), float64) Action space: Box(-1.0, 1.0, (3,), float32)
Humanoid-v2 Observation space: Box(-inf, inf, (376,), float64) Action space: Box(-0.4, 0.4, (17,), float32)
InvertedDoublePendulum-v2 Observation space: Box(-inf, inf, (11,), float64) Action space: Box(-1.0, 1.0, (1,), float32)
InvertedPendulum-v2 Observation space: Box(-inf, inf, (4,), float64) Action space: Box(-3.0, 3.0, (1,), float32)
Pusher-v2 Observation space: Box(-inf, inf, (23,), float64) Action space: Box(-2.0, 2.0, (7,), float32)
Reacher-v2 Observation space: Box(-inf, inf, (11,), float64) Action space: Box(-1.0, 1.0, (2,), float32)
Swimmer-v2 Observation space: Box(-inf, inf, (8,), float64) Action space: Box(-1.0, 1.0, (2,), float32)
Walker2d-v2 Observation space: Box(-inf, inf, (17,), float64) Action space: Box(-1.0, 1.0, (6,), float32)
```

### Experiment results

@@ -79,6 +126,9 @@ Below are the average episodic returns for [`td3_continuous_action.py`](https://
| HalfCheetah | 9018.31 ± 1078.31 | 9636.95 ± 859.065 |
| Walker2d | 4246.07 ± 1210.84 | 4682.82 ± 539.64 |
| Hopper | 3391.78 ± 232.21 | 3564.07 ± 114.74 |
| Humanoid | 4822.64 ± 321.85 | not available |
| Pusher | -42.24 ± 6.74 | not available |
| InvertedPendulum | 964.59 ± 43.91 | 1000.00 ± 0.00 |



@@ -98,6 +148,12 @@ Learning curves:
<img src="../td3/Walker2d-v2.png">

<img src="../td3/Hopper-v2.png">

<img src="../td3/Humanoid-v2.png">

<img src="../td3/Pusher-v2.png">

<img src="../td3/InvertedPendulum-v2.png">
</div>


Binary file added docs/rl-algorithms/td3/Humanoid-v2.png
Binary file added docs/rl-algorithms/td3/InvertedPendulum-v2.png
Binary file added docs/rl-algorithms/td3/Pusher-v2.png