Commit 0cf0e9e

remove tensorflow_probability
1 parent 668ea1d commit 0cf0e9e

1 file changed, +27 -44 lines changed

cleanrl/sac_continuous_action_jax.py

@@ -7,7 +7,7 @@
 from dataclasses import dataclass
 from distutils.util import strtobool
 from functools import partial
-from typing import Any, Optional, Sequence
+from typing import Sequence
 
 import flax
 import flax.linen as nn
@@ -18,7 +18,6 @@
 import optax
 
 # import pybullet_envs  # noqa
-import tensorflow_probability
 from flax.training.train_state import TrainState
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.env_util import make_vec_env
@@ -32,9 +31,6 @@
 except ImportError:
     tqdm = None
 
-tfp = tensorflow_probability.substrates.jax
-tfd = tfp.distributions
-
 
 def parse_args():
     # fmt: off
@@ -107,25 +103,6 @@ def thunk():
     return thunk
 
 
-class TanhTransformedDistribution(tfd.TransformedDistribution):
-    """
-    From https://github.com/ikostrikov/walk_in_the_park
-    otherwise mode is not defined for Squashed Gaussian
-    """
-
-    def __init__(self, distribution: tfd.Distribution, validate_args: bool = False):
-        super().__init__(distribution=distribution, bijector=tfp.bijectors.Tanh(), validate_args=validate_args)
-
-    def mode(self) -> jnp.ndarray:
-        return self.bijector.forward(self.distribution.mode())
-
-    @classmethod
-    def _parameter_properties(cls, dtype: Optional[Any], num_classes=None):
-        td_properties = super()._parameter_properties(dtype, num_classes=num_classes)
-        del td_properties["bijector"]
-        return td_properties
-
-
 class Critic(nn.Module):
     n_units: int = 256
 
@@ -169,18 +146,15 @@ class Actor(nn.Module):
     log_std_max: float = 2
 
     @nn.compact
-    def __call__(self, x: jnp.ndarray) -> tfd.Distribution:
+    def __call__(self, x: jnp.ndarray):
         x = nn.Dense(self.n_units)(x)
         x = nn.relu(x)
         x = nn.Dense(self.n_units)(x)
         x = nn.relu(x)
         mean = nn.Dense(self.action_dim)(x)
         log_std = nn.Dense(self.action_dim)(x)
         log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
-        dist = TanhTransformedDistribution(
-            tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)),
-        )
-        return dist
+        return mean, log_std
 
 
 class RLTrainState(TrainState):
@@ -194,15 +168,17 @@ def sample_action(
     observations: jnp.ndarray,
     key: jax.random.KeyArray,
 ) -> jnp.array:
-    key, noise_key = jax.random.split(key, 2)
-    dist = actor.apply(actor_state.params, observations)
-    action = dist.sample(seed=noise_key)
+    key, subkey = jax.random.split(key, 2)
+    action_mean, action_logstd = actor.apply(actor_state.params, observations)
+    action_std = jnp.exp(action_logstd)
+    action = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape)
+    action = jnp.tanh(action)
     return action, key
 
 
 @partial(jax.jit, static_argnames="actor")
 def select_action(actor: Actor, actor_state: TrainState, observations: jnp.ndarray) -> jnp.array:
-    return actor.apply(actor_state.params, observations).mode()
+    return actor.apply(actor_state.params, observations)[0]
 
 
 def scale_action(action_space: gym.spaces.Box, action: np.ndarray) -> np.ndarray:
@@ -352,12 +328,17 @@ def update_critic(
     dones: np.ndarray,
     key: jax.random.KeyArray,
 ):
-    key, noise_key = jax.random.split(key, 2)
+    key, subkey = jax.random.split(key, 2)
+    action_mean, action_logstd = actor.apply(actor_state.params, next_observations)
     # sample action from the actor
-    dist = actor.apply(actor_state.params, next_observations)
-    next_state_actions = dist.sample(seed=noise_key)
-    next_log_prob = dist.log_prob(next_state_actions)
-
+    action_std = jnp.exp(action_logstd)
+    next_state_actions = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape)
+    next_log_prob = (
+        -0.5 * ((next_state_actions - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd
+    )
+    next_state_actions = jnp.tanh(next_state_actions)
+    next_log_prob -= jnp.log((1 - jnp.power(next_state_actions, 2)) + 1e-6)
+    next_log_prob = next_log_prob.sum(axis=1)
     qf_next_values = qf.apply(qf_state.target_params, next_observations, next_state_actions)
 
     next_q_values = jnp.min(qf_next_values, axis=0)
@@ -388,14 +369,16 @@ def update_actor(
     observations: np.ndarray,
     key: jax.random.KeyArray,
 ):
-    key, noise_key = jax.random.split(key, 2)
+    key, subkey = jax.random.split(key, 2)
 
     def actor_loss(params):
-
-        dist = actor.apply(params, observations)
-        actor_actions = dist.sample(seed=noise_key)
-        log_prob = dist.log_prob(actor_actions).reshape(-1, 1)
-
+        action_mean, action_logstd = actor.apply(params, observations)
+        action_std = jnp.exp(action_logstd)
+        actor_actions = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape)
+        log_prob = -0.5 * ((actor_actions - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd
+        actor_actions = jnp.tanh(actor_actions)
+        log_prob -= jnp.log((1 - jnp.power(actor_actions, 2)) + 1e-6)
+        log_prob = log_prob.sum(axis=1, keepdims=True)
         qf_pi = qf.apply(qf_state.params, observations, actor_actions)
         # Take min among all critics
         min_qf_pi = jnp.min(qf_pi, axis=0)
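
For reference, below is a minimal, self-contained sketch (not part of this commit) of the tanh-squashed Gaussian math that the diff now inlines in sample_action, update_critic, and update_actor. The helper name sample_and_log_prob and the test shapes are illustrative assumptions; the formulas mirror the diff above: a reparameterized draw from the diagonal Gaussian, its per-dimension log-density, the tanh squash, and the change-of-variables correction log(1 - tanh(u)^2 + 1e-6) summed over action dimensions.

import jax
import jax.numpy as jnp


def sample_and_log_prob(mean, log_std, key):
    # Reparameterized draw from the underlying diagonal Gaussian.
    std = jnp.exp(log_std)
    gaussian_action = mean + std * jax.random.normal(key, shape=mean.shape)
    # Per-dimension log-density of the diagonal Gaussian: log N(u; mean, std).
    log_prob = -0.5 * ((gaussian_action - mean) / std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - log_std
    # Squash into (-1, 1) and apply the change-of-variables correction
    # log(1 - tanh(u)^2); the 1e-6 keeps the log finite when tanh saturates.
    action = jnp.tanh(gaussian_action)
    log_prob -= jnp.log(1 - jnp.power(action, 2) + 1e-6)
    # The joint log-probability is the sum over action dimensions.
    return action, log_prob.sum(axis=-1)


# Illustrative usage: a batch of 4 observations with a 3-dimensional action space.
key = jax.random.PRNGKey(0)
mean = jnp.zeros((4, 3))
log_std = -0.5 * jnp.ones((4, 3))
action, log_prob = sample_and_log_prob(mean, log_std, key)
print(action.shape, log_prob.shape)  # (4, 3) (4,)

This is roughly what TanhTransformedDistribution.log_prob computed via tfp.bijectors.Tanh before this commit; inlining it removes the tensorflow_probability dependency at the cost of carrying the epsilon-stabilized correction term by hand.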
