distributions.py
"""Probability distributions."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union
import gym
import numpy as np
import torch as th
from gym import spaces
from torch import nn
from torch.distributions import Bernoulli, Categorical, Normal
from stable_baselines3.common.preprocessing import get_action_dim
class Distribution(ABC):
"""Abstract base class for distributions."""
def __init__(self):
super().__init__()
self.distribution = None
@abstractmethod
def proba_distribution_net(self, *args, **kwargs) -> Union[nn.Module, Tuple[nn.Module, nn.Parameter]]:
"""Create the layers and parameters that represent the distribution.
Subclasses must define this, but the arguments and return type vary between
concrete classes."""
@abstractmethod
def proba_distribution(self, *args, **kwargs) -> "Distribution":
"""Set parameters of the distribution.
:return: self
"""
@abstractmethod
def log_prob(self, x: th.Tensor) -> th.Tensor:
"""
Returns the log likelihood
:param x: the taken action
:return: The log likelihood of the distribution
"""
@abstractmethod
def entropy(self) -> Optional[th.Tensor]:
"""
        Returns Shannon's entropy of the probability distribution
:return: the entropy, or None if no analytical form is known
"""
@abstractmethod
def sample(self) -> th.Tensor:
"""
Returns a sample from the probability distribution
:return: the stochastic action
"""
@abstractmethod
def mode(self) -> th.Tensor:
"""
Returns the most likely action (deterministic output)
from the probability distribution
        :return: the most likely (deterministic) action
"""
def get_actions(self, deterministic: bool = False) -> th.Tensor:
"""
Return actions according to the probability distribution.
        :param deterministic: Whether to return the mode of the distribution (deterministic action) or a stochastic sample
        :return: the selected actions
"""
if deterministic:
return self.mode()
return self.sample()
@abstractmethod
def actions_from_params(self, *args, **kwargs) -> th.Tensor:
"""
Returns samples from the probability distribution
given its parameters.
:return: actions
"""
@abstractmethod
def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]:
"""
Returns samples and the associated log probabilities
from the probability distribution given its parameters.
:return: actions and log prob
"""
def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
"""
Continuous actions are usually considered to be independent,
so we can sum components of the ``log_prob`` or the entropy.
:param tensor: shape: (n_batch, n_actions) or (n_batch,)
:return: shape: (n_batch,)
"""
if len(tensor.shape) > 1:
tensor = tensor.sum(dim=1)
else:
tensor = tensor.sum()
return tensor
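
# Illustrative sketch of the helper above: for independent action dimensions, the joint
# log-probability is the sum of the per-dimension log-probabilities (shapes are arbitrary).
def _example_sum_independent_dims() -> th.Tensor:
    per_dim_log_prob = Normal(th.zeros(4, 3), th.ones(4, 3)).log_prob(th.zeros(4, 3))  # shape (4, 3)
    joint_log_prob = sum_independent_dims(per_dim_log_prob)  # shape (4,)
    assert joint_log_prob.shape == (4,)
    return joint_log_prob
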
class DiagGaussianDistribution(Distribution):
"""
Gaussian distribution with diagonal covariance matrix, for continuous actions.
:param action_dim: Dimension of the action space.
"""
def __init__(self, action_dim: int):
super().__init__()
self.action_dim = action_dim
self.mean_actions = None
self.log_std = None
def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]:
"""
Create the layers and parameter that represent the distribution:
one output will be the mean of the Gaussian, the other parameter will be the
        standard deviation (stored as the log std, so the parameter is unconstrained and can be negative)
:param latent_dim: Dimension of the last layer of the policy (before the action layer)
:param log_std_init: Initial value for the log standard deviation
:return:
"""
mean_actions = nn.Linear(latent_dim, self.action_dim)
# TODO: allow action dependent std
log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init, requires_grad=True)
return mean_actions, log_std
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution":
"""
Create the distribution given its parameters (mean, std)
:param mean_actions:
:param log_std:
:return:
"""
action_std = th.ones_like(mean_actions) * log_std.exp()
self.distribution = Normal(mean_actions, action_std)
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
"""
Get the log probabilities of actions according to the distribution.
Note that you must first call the ``proba_distribution()`` method.
:param actions:
:return:
"""
log_prob = self.distribution.log_prob(actions)
return sum_independent_dims(log_prob)
def entropy(self) -> th.Tensor:
return sum_independent_dims(self.distribution.entropy())
def sample(self) -> th.Tensor:
# Reparametrization trick to pass gradients
return self.distribution.rsample()
def mode(self) -> th.Tensor:
return self.distribution.mean
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(mean_actions, log_std)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
"""
Compute the log probability of taking an action
given the distribution parameters.
:param mean_actions:
:param log_std:
:return:
"""
actions = self.actions_from_params(mean_actions, log_std)
log_prob = self.log_prob(actions)
return actions, log_prob
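
# Illustrative usage sketch of the two-step API above (the dimensions -- latent_dim=8,
# action_dim=2, batch of 4 -- are arbitrary): build the mean network and the log-std parameter
# once, then condition the distribution on a batch of latent features.
def _example_diag_gaussian() -> Tuple[th.Tensor, th.Tensor]:
    dist = DiagGaussianDistribution(action_dim=2)
    mean_net, log_std = dist.proba_distribution_net(latent_dim=8)
    latent = th.randn(4, 8)  # batch of latent features from the policy network
    dist.proba_distribution(mean_net(latent), log_std)
    actions = dist.get_actions(deterministic=False)  # shape (4, 2)
    log_prob = dist.log_prob(actions)  # shape (4,), summed over the action dimensions
    return actions, log_prob
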
class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
"""
Gaussian distribution with diagonal covariance matrix, followed by a squashing function (tanh) to ensure bounds.
:param action_dim: Dimension of the action space.
:param epsilon: small value to avoid NaN due to numerical imprecision.
"""
def __init__(self, action_dim: int, epsilon: float = 1e-6):
super().__init__(action_dim)
# Avoid NaN (prevents division by zero or log of zero)
self.epsilon = epsilon
self.gaussian_actions = None
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "SquashedDiagGaussianDistribution":
super().proba_distribution(mean_actions, log_std)
return self
def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
        # Inverse tanh
        # Naive implementation (not stable): 0.5 * torch.log((1 + x) / (1 - x))
        # Instead, a clamped implementation (TanhBijector.inverse) is used to avoid numerical instability
if gaussian_actions is None:
            # It will be clipped to avoid NaN when inverting tanh
gaussian_actions = TanhBijector.inverse(actions)
# Log likelihood for a Gaussian distribution
log_prob = super().log_prob(gaussian_actions)
# Squash correction (from original SAC implementation)
# this comes from the fact that tanh is bijective and differentiable
log_prob -= th.sum(th.log(1 - actions**2 + self.epsilon), dim=1)
return log_prob
def entropy(self) -> Optional[th.Tensor]:
# No analytical form,
# entropy needs to be estimated using -log_prob.mean()
return None
def sample(self) -> th.Tensor:
# Reparametrization trick to pass gradients
self.gaussian_actions = super().sample()
return th.tanh(self.gaussian_actions)
def mode(self) -> th.Tensor:
self.gaussian_actions = super().mode()
# Squash the output
return th.tanh(self.gaussian_actions)
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
action = self.actions_from_params(mean_actions, log_std)
log_prob = self.log_prob(action, self.gaussian_actions)
return action, log_prob
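
# Illustrative sketch: the squashed variant returns actions bounded in (-1, 1) and its
# log_prob includes the tanh correction term (dimensions are arbitrary).
def _example_squashed_gaussian() -> Tuple[th.Tensor, th.Tensor]:
    dist = SquashedDiagGaussianDistribution(action_dim=2)
    mean_net, log_std = dist.proba_distribution_net(latent_dim=8)
    actions, log_prob = dist.log_prob_from_params(mean_net(th.randn(4, 8)), log_std)
    assert (actions.abs() < 1.0).all()  # bounded by the tanh squashing
    return actions, log_prob
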
class CategoricalDistribution(Distribution):
"""
Categorical distribution for discrete actions.
:param action_dim: Number of discrete actions
"""
def __init__(self, action_dim: int):
super().__init__()
self.action_dim = action_dim
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
"""
Create the layer that represents the distribution:
it will be the logits of the Categorical distribution.
You can then get probabilities using a softmax.
:param latent_dim: Dimension of the last layer
of the policy network (before the action layer)
:return:
"""
action_logits = nn.Linear(latent_dim, self.action_dim)
return action_logits
def proba_distribution(self, action_logits: th.Tensor) -> "CategoricalDistribution":
self.distribution = Categorical(logits=action_logits)
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
return self.distribution.log_prob(actions)
def entropy(self) -> th.Tensor:
return self.distribution.entropy()
def sample(self) -> th.Tensor:
return self.distribution.sample()
def mode(self) -> th.Tensor:
return th.argmax(self.distribution.probs, dim=1)
def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob
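
# Illustrative sketch for a Discrete action space with 5 actions (dimensions are arbitrary):
# the linear layer outputs one logit per action and sampled actions are integer indices.
def _example_categorical() -> Tuple[th.Tensor, th.Tensor]:
    dist = CategoricalDistribution(action_dim=5)
    logits_net = dist.proba_distribution_net(latent_dim=8)
    actions, log_prob = dist.log_prob_from_params(logits_net(th.randn(4, 8)))
    assert actions.shape == (4,)  # one integer index in [0, 4] per batch element
    return actions, log_prob
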
class MultiCategoricalDistribution(Distribution):
"""
MultiCategorical distribution for multi discrete actions.
:param action_dims: List of sizes of discrete action spaces
"""
def __init__(self, action_dims: List[int]):
super().__init__()
self.action_dims = action_dims
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
"""
Create the layer that represents the distribution:
it will be the logits (flattened) of the MultiCategorical distribution.
You can then get probabilities using a softmax on each sub-space.
:param latent_dim: Dimension of the last layer
of the policy network (before the action layer)
:return:
"""
action_logits = nn.Linear(latent_dim, sum(self.action_dims))
return action_logits
def proba_distribution(self, action_logits: th.Tensor) -> "MultiCategoricalDistribution":
self.distribution = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
# Extract each discrete action and compute log prob for their respective distributions
return th.stack(
[dist.log_prob(action) for dist, action in zip(self.distribution, th.unbind(actions, dim=1))], dim=1
).sum(dim=1)
def entropy(self) -> th.Tensor:
return th.stack([dist.entropy() for dist in self.distribution], dim=1).sum(dim=1)
def sample(self) -> th.Tensor:
return th.stack([dist.sample() for dist in self.distribution], dim=1)
def mode(self) -> th.Tensor:
return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distribution], dim=1)
def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob
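
# Illustrative sketch for a MultiDiscrete([3, 2]) action space (dimensions are arbitrary):
# the single linear layer outputs 3 + 2 = 5 logits, which proba_distribution() splits
# into one Categorical per sub-space.
def _example_multi_categorical() -> th.Tensor:
    dist = MultiCategoricalDistribution(action_dims=[3, 2])
    logits_net = dist.proba_distribution_net(latent_dim=8)  # Linear(8, 5)
    dist.proba_distribution(logits_net(th.randn(4, 8)))
    actions = dist.sample()  # shape (4, 2), one index per sub-space
    return actions
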
class BernoulliDistribution(Distribution):
"""
Bernoulli distribution for MultiBinary action spaces.
    :param action_dims: Number of binary actions
"""
def __init__(self, action_dims: int):
super().__init__()
self.action_dims = action_dims
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
"""
Create the layer that represents the distribution:
it will be the logits of the Bernoulli distribution.
:param latent_dim: Dimension of the last layer
of the policy network (before the action layer)
:return:
"""
action_logits = nn.Linear(latent_dim, self.action_dims)
return action_logits
def proba_distribution(self, action_logits: th.Tensor) -> "BernoulliDistribution":
self.distribution = Bernoulli(logits=action_logits)
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
return self.distribution.log_prob(actions).sum(dim=1)
def entropy(self) -> th.Tensor:
return self.distribution.entropy().sum(dim=1)
def sample(self) -> th.Tensor:
return self.distribution.sample()
def mode(self) -> th.Tensor:
return th.round(self.distribution.probs)
def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob
class StateDependentNoiseDistribution(Distribution):
"""
Distribution class for using generalized State Dependent Exploration (gSDE).
Paper: https://arxiv.org/abs/2005.05719
It is used to create the noise exploration matrix and
compute the log probability of an action with that noise.
:param action_dim: Dimension of the action space.
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,)
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). This keeps the variance
        above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: Whether to squash the output using a tanh function,
this ensures bounds are satisfied.
:param learn_features: Whether to learn features for gSDE or not.
This will enable gradients to be backpropagated through the features
``latent_sde`` in the code.
:param epsilon: small value to avoid NaN due to numerical imprecision.
"""
def __init__(
self,
action_dim: int,
full_std: bool = True,
use_expln: bool = False,
squash_output: bool = False,
learn_features: bool = False,
epsilon: float = 1e-6,
):
super().__init__()
self.action_dim = action_dim
self.latent_sde_dim = None
self.mean_actions = None
self.log_std = None
self.weights_dist = None
self.exploration_mat = None
self.exploration_matrices = None
self._latent_sde = None
self.use_expln = use_expln
self.full_std = full_std
self.epsilon = epsilon
self.learn_features = learn_features
if squash_output:
self.bijector = TanhBijector(epsilon)
else:
self.bijector = None
def get_std(self, log_std: th.Tensor) -> th.Tensor:
"""
Get the standard deviation from the learned parameter
(log of it by default). This ensures that the std is positive.
:param log_std:
:return:
"""
if self.use_expln:
            # From the gSDE paper: this keeps the variance
            # above zero and prevents it from growing too fast
below_threshold = th.exp(log_std) * (log_std <= 0)
            # Avoid NaN: zero out the entries where log_std <= 0 before applying log1p
safe_log_std = log_std * (log_std > 0) + self.epsilon
above_threshold = (th.log1p(safe_log_std) + 1.0) * (log_std > 0)
std = below_threshold + above_threshold
else:
# Use normal exponential
std = th.exp(log_std)
if self.full_std:
return std
# Reduce the number of parameters:
return th.ones(self.latent_sde_dim, self.action_dim).to(log_std.device) * std
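
    # Worked example of the expln() branch above: expln(x) = exp(x) for x <= 0 and
    # log1p(x) + 1 for x > 0, so e.g. expln(-1.0) = exp(-1.0) ~= 0.37 and
    # expln(2.0) = log(3.0) + 1 ~= 2.10. The two branches meet at expln(0) = 1, and the std
    # grows only logarithmically for large positive log_std instead of exponentially.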
def sample_weights(self, log_std: th.Tensor, batch_size: int = 1) -> None:
"""
Sample weights for the noise exploration matrix,
using a centered Gaussian distribution.
:param log_std:
:param batch_size:
"""
std = self.get_std(log_std)
self.weights_dist = Normal(th.zeros_like(std), std)
# Reparametrization trick to pass gradients
self.exploration_mat = self.weights_dist.rsample()
# Pre-compute matrices in case of parallel exploration
self.exploration_matrices = self.weights_dist.rsample((batch_size,))
def proba_distribution_net(
self, latent_dim: int, log_std_init: float = -2.0, latent_sde_dim: Optional[int] = None
) -> Tuple[nn.Module, nn.Parameter]:
"""
Create the layers and parameter that represent the distribution:
        one output will be the deterministic action, the other parameter will be the
        standard deviation of the distribution that controls the weights of the noise matrix.
:param latent_dim: Dimension of the last layer of the policy (before the action layer)
:param log_std_init: Initial value for the log standard deviation
:param latent_sde_dim: Dimension of the last layer of the features extractor
for gSDE. By default, it is shared with the policy network.
:return:
"""
        # Network for the deterministic action; it represents the mean of the distribution
mean_actions_net = nn.Linear(latent_dim, self.action_dim)
# When we learn features for the noise, the feature dimension
# can be different between the policy and the noise network
self.latent_sde_dim = latent_dim if latent_sde_dim is None else latent_sde_dim
# Reduce the number of parameters if needed
log_std = th.ones(self.latent_sde_dim, self.action_dim) if self.full_std else th.ones(self.latent_sde_dim, 1)
# Transform it to a parameter so it can be optimized
log_std = nn.Parameter(log_std * log_std_init, requires_grad=True)
# Sample an exploration matrix
self.sample_weights(log_std)
return mean_actions_net, log_std
def proba_distribution(
self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
) -> "StateDependentNoiseDistribution":
"""
Create the distribution given its parameters (mean, std)
:param mean_actions:
:param log_std:
:param latent_sde:
:return:
"""
# Stop gradient if we don't want to influence the features
self._latent_sde = latent_sde if self.learn_features else latent_sde.detach()
variance = th.mm(self._latent_sde**2, self.get_std(log_std) ** 2)
self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon))
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
if self.bijector is not None:
gaussian_actions = self.bijector.inverse(actions)
else:
gaussian_actions = actions
# log likelihood for a gaussian
log_prob = self.distribution.log_prob(gaussian_actions)
# Sum along action dim
log_prob = sum_independent_dims(log_prob)
if self.bijector is not None:
# Squash correction (from original SAC implementation)
log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1)
return log_prob
def entropy(self) -> Optional[th.Tensor]:
if self.bijector is not None:
# No analytical form,
# entropy needs to be estimated using -log_prob.mean()
return None
return sum_independent_dims(self.distribution.entropy())
def sample(self) -> th.Tensor:
noise = self.get_noise(self._latent_sde)
actions = self.distribution.mean + noise
if self.bijector is not None:
return self.bijector.forward(actions)
return actions
def mode(self) -> th.Tensor:
actions = self.distribution.mean
if self.bijector is not None:
return self.bijector.forward(actions)
return actions
def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
latent_sde = latent_sde if self.learn_features else latent_sde.detach()
# Default case: only one exploration matrix
if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
return th.mm(latent_sde, self.exploration_mat)
# Use batch matrix multiplication for efficient computation
# (batch_size, n_features) -> (batch_size, 1, n_features)
latent_sde = latent_sde.unsqueeze(dim=1)
# (batch_size, 1, n_actions)
noise = th.bmm(latent_sde, self.exploration_matrices)
return noise.squeeze(dim=1)
def actions_from_params(
self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, deterministic: bool = False
) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(mean_actions, log_std, latent_sde)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(
self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(mean_actions, log_std, latent_sde)
log_prob = self.log_prob(actions)
return actions, log_prob
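
# Illustrative gSDE usage sketch (dimensions are arbitrary): the exploration noise is a
# deterministic function of the latent features and of an exploration matrix that stays fixed
# between calls to sample_weights(), which is what makes the exploration state dependent.
def _example_gsde() -> th.Tensor:
    dist = StateDependentNoiseDistribution(action_dim=2)
    mean_net, log_std = dist.proba_distribution_net(latent_dim=8)  # also samples an exploration matrix
    latent = th.randn(4, 8)
    dist.proba_distribution(mean_net(latent), log_std, latent_sde=latent)
    actions = dist.sample()  # mean + latent @ exploration_mat, shape (4, 2)
    return actions
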
class TanhBijector:
"""
Bijective transformation of a probability distribution
using a squashing function (tanh)
TODO: use Pyro instead (https://pyro.ai/)
:param epsilon: small value to avoid NaN due to numerical imprecision.
"""
def __init__(self, epsilon: float = 1e-6):
super().__init__()
self.epsilon = epsilon
@staticmethod
def forward(x: th.Tensor) -> th.Tensor:
return th.tanh(x)
@staticmethod
def atanh(x: th.Tensor) -> th.Tensor:
"""
Inverse of Tanh
Taken from Pyro: https://github.com/pyro-ppl/pyro
        0.5 * torch.log((1 + x) / (1 - x))
"""
return 0.5 * (x.log1p() - (-x).log1p())
@staticmethod
def inverse(y: th.Tensor) -> th.Tensor:
"""
Inverse tanh.
:param y:
:return:
"""
eps = th.finfo(y.dtype).eps
# Clip the action to avoid NaN
return TanhBijector.atanh(y.clamp(min=-1.0 + eps, max=1.0 - eps))
def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
# Squash correction (from original SAC implementation)
return th.log(1.0 - th.tanh(x) ** 2 + self.epsilon)
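
# Illustrative sketch: forward() and inverse() are (numerically clamped) inverses of each other.
def _example_tanh_bijector() -> th.Tensor:
    x = th.linspace(-2.0, 2.0, 5)
    y = TanhBijector.forward(x)  # squashed into (-1, 1)
    x_reconstructed = TanhBijector.inverse(y)  # recovers x up to clamping near the bounds
    assert th.allclose(x, x_reconstructed, atol=1e-4)
    return x_reconstructed
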
def make_proba_distribution(
action_space: gym.spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
) -> Distribution:
"""
Return an instance of Distribution for the correct type of action space
:param action_space: the input action space
:param use_sde: Force the use of StateDependentNoiseDistribution
instead of DiagGaussianDistribution
:param dist_kwargs: Keyword arguments to pass to the probability distribution
:return: the appropriate Distribution object
"""
if dist_kwargs is None:
dist_kwargs = {}
if isinstance(action_space, spaces.Box):
cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
return cls(get_action_dim(action_space), **dist_kwargs)
elif isinstance(action_space, spaces.Discrete):
return CategoricalDistribution(action_space.n, **dist_kwargs)
elif isinstance(action_space, spaces.MultiDiscrete):
return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
elif isinstance(action_space, spaces.MultiBinary):
return BernoulliDistribution(action_space.n, **dist_kwargs)
else:
raise NotImplementedError(
"Error: probability distribution, not implemented for action space"
f"of type {type(action_space)}."
" Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary."
)
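
# Illustrative sketch of the dispatch above (the example spaces are arbitrary):
def _example_make_proba_distribution() -> None:
    assert isinstance(make_proba_distribution(spaces.Box(low=-1.0, high=1.0, shape=(3,))), DiagGaussianDistribution)
    assert isinstance(make_proba_distribution(spaces.Discrete(5)), CategoricalDistribution)
    assert isinstance(make_proba_distribution(spaces.MultiDiscrete([3, 2])), MultiCategoricalDistribution)
    assert isinstance(make_proba_distribution(spaces.MultiBinary(4)), BernoulliDistribution)
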
def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor:
"""
Wrapper for the PyTorch implementation of the full form KL Divergence
:param dist_true: the p distribution
:param dist_pred: the q distribution
:return: KL(dist_true||dist_pred)
"""
# KL Divergence for different distribution types is out of scope
assert dist_true.__class__ == dist_pred.__class__, "Error: input distributions should be the same type"
# MultiCategoricalDistribution is not a PyTorch Distribution subclass
# so we need to implement it ourselves!
if isinstance(dist_pred, MultiCategoricalDistribution):
assert np.allclose(dist_pred.action_dims, dist_true.action_dims), "Error: distributions must have the same input space"
return th.stack(
[th.distributions.kl_divergence(p, q) for p, q in zip(dist_true.distribution, dist_pred.distribution)],
dim=1,
).sum(dim=1)
# Use the PyTorch kl_divergence implementation
else:
return th.distributions.kl_divergence(dist_true.distribution, dist_pred.distribution)
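
# Illustrative sketch of the KL wrapper above: both arguments must be the same distribution type,
# and for the diagonal Gaussian case PyTorch returns the per-dimension KL (parameters are arbitrary).
def _example_kl_divergence() -> th.Tensor:
    dist_true = DiagGaussianDistribution(action_dim=2).proba_distribution(th.zeros(1, 2), th.zeros(2))
    dist_pred = DiagGaussianDistribution(action_dim=2).proba_distribution(th.ones(1, 2), th.zeros(2))
    kl = kl_divergence(dist_true, dist_pred)  # shape (1, 2); KL(N(0, 1) || N(1, 1)) = 0.5 per dimension
    assert th.allclose(kl, th.full((1, 2), 0.5))
    return kl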