@@ -7,7 +7,7 @@
 from dataclasses import dataclass
 from distutils.util import strtobool
 from functools import partial
-from typing import Any, Optional, Sequence
+from typing import Sequence
 
 import flax
 import flax.linen as nn
@@ -18,7 +18,6 @@
 import optax
 
 # import pybullet_envs  # noqa
-import tensorflow_probability
 from flax.training.train_state import TrainState
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.env_util import make_vec_env
@@ -32,9 +31,6 @@
 except ImportError:
     tqdm = None
 
-tfp = tensorflow_probability.substrates.jax
-tfd = tfp.distributions
-
 
 def parse_args():
     # fmt: off
@@ -107,25 +103,6 @@ def thunk():
     return thunk
 
 
-class TanhTransformedDistribution(tfd.TransformedDistribution):
-    """
-    From https://github.com/ikostrikov/walk_in_the_park
-    otherwise mode is not defined for Squashed Gaussian
-    """
-
-    def __init__(self, distribution: tfd.Distribution, validate_args: bool = False):
-        super().__init__(distribution=distribution, bijector=tfp.bijectors.Tanh(), validate_args=validate_args)
-
-    def mode(self) -> jnp.ndarray:
-        return self.bijector.forward(self.distribution.mode())
-
-    @classmethod
-    def _parameter_properties(cls, dtype: Optional[Any], num_classes=None):
-        td_properties = super()._parameter_properties(dtype, num_classes=num_classes)
-        del td_properties["bijector"]
-        return td_properties
-
-
 class Critic(nn.Module):
     n_units: int = 256
 
@@ -169,18 +146,15 @@ class Actor(nn.Module):
     log_std_max: float = 2
 
     @nn.compact
-    def __call__(self, x: jnp.ndarray) -> tfd.Distribution:
+    def __call__(self, x: jnp.ndarray):
         x = nn.Dense(self.n_units)(x)
         x = nn.relu(x)
         x = nn.Dense(self.n_units)(x)
         x = nn.relu(x)
         mean = nn.Dense(self.action_dim)(x)
         log_std = nn.Dense(self.action_dim)(x)
         log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
-        dist = TanhTransformedDistribution(
-            tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)),
-        )
-        return dist
+        return mean, log_std
 
 
 class RLTrainState(TrainState):
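For reference, this is the squashed-Gaussian math that the update_critic and update_actor hunks below now inline: sample u ~ N(mean, exp(log_std)) with the reparameterization trick, squash with a = tanh(u), and correct the density with log pi(a) = log N(u) - sum_i log(1 - tanh(u_i)^2 + 1e-6). A minimal standalone sketch; the `sample_and_log_prob` helper name is illustrative only and not part of this diff:

```python
import jax
import jax.numpy as jnp


def sample_and_log_prob(mean: jnp.ndarray, log_std: jnp.ndarray, key: jax.random.KeyArray):
    # Reparameterized sample from N(mean, exp(log_std)).
    std = jnp.exp(log_std)
    gaussian_action = mean + std * jax.random.normal(key, shape=mean.shape)
    # Per-dimension log-density of the pre-squash Gaussian sample.
    log_prob = -0.5 * ((gaussian_action - mean) / std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - log_std
    # Tanh squashing plus the change-of-variables correction (1e-6 guards log(0) at the bounds).
    action = jnp.tanh(gaussian_action)
    log_prob -= jnp.log(1 - action**2 + 1e-6)
    return action, log_prob.sum(axis=-1)
```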
@@ -194,15 +168,17 @@ def sample_action(
     observations: jnp.ndarray,
     key: jax.random.KeyArray,
 ) -> jnp.array:
-    key, noise_key = jax.random.split(key, 2)
-    dist = actor.apply(actor_state.params, observations)
-    action = dist.sample(seed=noise_key)
+    key, subkey = jax.random.split(key, 2)
+    action_mean, action_logstd = actor.apply(actor_state.params, observations)
+    action_std = jnp.exp(action_logstd)
+    action = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape)
+    action = jnp.tanh(action)
     return action, key
 
 
 @partial(jax.jit, static_argnames="actor")
 def select_action(actor: Actor, actor_state: TrainState, observations: jnp.ndarray) -> jnp.array:
-    return actor.apply(actor_state.params, observations).mode()
+    return actor.apply(actor_state.params, observations)[0]
 
 
 def scale_action(action_space: gym.spaces.Box, action: np.ndarray) -> np.ndarray:
@@ -352,12 +328,17 @@ def update_critic(
     dones: np.ndarray,
     key: jax.random.KeyArray,
 ):
-    key, noise_key = jax.random.split(key, 2)
+    key, subkey = jax.random.split(key, 2)
+    action_mean, action_logstd = actor.apply(actor_state.params, next_observations)
     # sample action from the actor
-    dist = actor.apply(actor_state.params, next_observations)
-    next_state_actions = dist.sample(seed=noise_key)
-    next_log_prob = dist.log_prob(next_state_actions)
-
+    action_std = jnp.exp(action_logstd)
+    next_state_actions = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape)
+    next_log_prob = (
+        -0.5 * ((next_state_actions - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd
+    )
+    next_state_actions = jnp.tanh(next_state_actions)
+    next_log_prob -= jnp.log((1 - jnp.power(next_state_actions, 2)) + 1e-6)
+    next_log_prob = next_log_prob.sum(axis=1)
     qf_next_values = qf.apply(qf_state.target_params, next_observations, next_state_actions)
 
     next_q_values = jnp.min(qf_next_values, axis=0)
@@ -388,14 +369,16 @@ def update_actor(
     observations: np.ndarray,
     key: jax.random.KeyArray,
 ):
-    key, noise_key = jax.random.split(key, 2)
+    key, subkey = jax.random.split(key, 2)
 
     def actor_loss(params):
-
-        dist = actor.apply(params, observations)
-        actor_actions = dist.sample(seed=noise_key)
-        log_prob = dist.log_prob(actor_actions).reshape(-1, 1)
-
+        action_mean, action_logstd = actor.apply(params, observations)
+        action_std = jnp.exp(action_logstd)
+        actor_actions = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape)
+        log_prob = -0.5 * ((actor_actions - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd
+        actor_actions = jnp.tanh(actor_actions)
+        log_prob -= jnp.log((1 - jnp.power(actor_actions, 2)) + 1e-6)
+        log_prob = log_prob.sum(axis=1, keepdims=True)
         qf_pi = qf.apply(qf_state.params, observations, actor_actions)
         # Take min among all critics
         min_qf_pi = jnp.min(qf_pi, axis=0)
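If you want to sanity-check the hand-rolled log-probability against the tfp distribution this PR removes, a throwaway snippet along these lines (assuming tensorflow-probability is still installed locally; the test values are arbitrary) should agree to within the 1e-6 clamp:

```python
import jax
import jax.numpy as jnp
import tensorflow_probability

tfp = tensorflow_probability.substrates.jax
tfd = tfp.distributions

mean, log_std = jnp.array([[0.3, -0.7]]), jnp.array([[-1.0, 0.5]])
std = jnp.exp(log_std)
gaussian_action = mean + std * jax.random.normal(jax.random.PRNGKey(0), shape=mean.shape)

# Manual computation, same formula as the diff above.
log_prob = (-0.5 * ((gaussian_action - mean) / std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - log_std).sum(axis=1)
action = jnp.tanh(gaussian_action)
log_prob -= jnp.log(1 - action**2 + 1e-6).sum(axis=1)

# Reference computation with the distribution classes being removed.
dist = tfd.TransformedDistribution(tfd.MultivariateNormalDiag(loc=mean, scale_diag=std), bijector=tfp.bijectors.Tanh())
print(log_prob, dist.log_prob(action))  # expected to match to roughly 1e-5
```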