@@ -2,23 +2,82 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from copy import copy, deepcopy
 
 import torch
 from tensordict import TensorDict, TensorDictBase, unravel_key
-from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
+from tensordict.nn import (
+    ProbabilisticTensorDictModule,
+    ProbabilisticTensorDictSequential,
+    TensorDictModuleBase,
+    TensorDictParams,
+    TensorDictSequential,
+)
 from tensordict.utils import is_seq_of_nested_key
 from torch import nn
 from torchrl.data.tensor_specs import Composite, Unbounded
 from torchrl.envs.transforms.transforms import Transform
 from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param
 
+# TODO: This should live somewhere else
+class ReferenceModelLogProbTransform(Transform):
+    """A transform to compute and store the log-probabilities from the reference model."""
+
+    def __init__(
+        self,
+        frozen_model: ProbabilisticTensorDictModule,
+    ):
+        super().__init__(in_keys=frozen_model.in_keys, out_keys=frozen_model.out_keys)
+        self.frozen_model: ProbabilisticTensorDictModule = frozen_model
+
+    def _call(self, inputs: TensorDict) -> TensorDict:
+        # Compute the log-prob given the reference model
+        return self.frozen_model(inputs)
+
+class KLDivergenceTransform(Transform):
+    """A transform to compute the KL divergence between the current and reference policies."""
+
+    ...
+
+
+class RewardAdjustmentTransform(Transform):
+    """A transform to adjust the reward based on the computed KL divergence."""
+
+    ...
+
+
+class KLConstrainedTransform(Composite):
+    """A composite transform to apply KL-based constraints on the policy."""
+
+    ...
+
 
 class KLRewardTransform(Transform):
-    """A transform to add a KL[pi_current||pi_0] correction term to the reward.
+    r"""A transform to add a KL divergence correction term to the reward.
 
     This transform is used to constrain the policy to remain close to its original
-    configuration which limits overfitting when fine-tuning using RLHF.
+    configuration, which helps limit overfitting when fine-tuning using Reinforcement Learning from Human Feedback
+    (RLHF) or other forms of post-training (e.g., GRPO).
+    The KL divergence between the current policy distribution and the reference policy distribution is used to adjust the reward:
+
+    .. math::
+
+        R_{\text{adjusted}} = R - \text{coef} \times \text{KL}(\pi_{\text{current}} || \pi_0)
+
+    where \( R_{\text{adjusted}} \) is the adjusted reward, \( R \) is the original reward, and
+    \(\text{KL}(\pi_{\text{current}} || \pi_0)\) is the Kullback-Leibler divergence between the current policy
+    distribution \( \pi_{\text{current}} \) and the reference policy distribution \( \pi_0 \).
+
+    The KL divergence can be estimated using the difference in log probabilities of the actions:
+
+    .. math::
+
+        \text{KL}(\pi_{\text{current}} || \pi_0) \approx \log p(a \mid \theta_{\text{current}}) - \log p(a \mid \theta_0)
+
+    where \( \log p(a \mid \theta_{\text{current}}) \) is the log probability of action \( a \) under the current model, and
+    \( \log p(a \mid \theta_0) \) is the log probability of action \( a \) under the reference model.
+
 
     Args:
         actor (ProbabilisticTensorDictModule): a probabilistic actor. It must
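As a quick numerical illustration of the two formulas in the docstring above (not part of the diff; the tensors and the 0.1 coefficient are made up for the example), the per-sample KL estimate and the adjusted reward follow directly from the log-probabilities:

    import torch

    # Log-probabilities of the sampled actions under the current policy
    # and under the frozen reference policy (illustrative values only).
    log_prob_current = torch.tensor([-1.2, -0.7, -2.1])
    log_prob_ref = torch.tensor([-1.5, -0.9, -1.8])

    # Single-sample KL estimate: log p(a | theta_current) - log p(a | theta_0)
    kl_estimate = log_prob_current - log_prob_ref  # [0.3, 0.2, -0.3]

    # Adjusted reward: R_adjusted = R - coef * KL
    reward = torch.tensor([0.0, 0.0, 1.0])
    coef = 0.1
    adjusted_reward = reward - coef * kl_estimate  # approx [-0.03, -0.02, 1.03]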
@@ -86,6 +145,11 @@ def __init__(
         out_keys=None,
         requires_grad=False,
     ):
+        warnings.warn(
+            "This class will be removed in a future release (v0.10.0). Please use torchrl.envs.KLConstrainedTransform "
+            "instead.",
+            category=FutureWarning,
+        )
         if in_keys is None:
             in_keys = self.DEFAULT_IN_KEYS
         if out_keys is None:
@@ -160,7 +224,7 @@ def _reset(
 
     def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
         # run the actor on the tensordict
-        action = next_tensordict.get("action", None)
+        action = next_tensordict.get("action")
         if action is None:
             # being called after reset or without action, skipping
             if self.out_keys[0] != ("reward",) and self.parent is not None:
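For context on what ReferenceModelLogProbTransform is meant to capture, here is a minimal, framework-agnostic sketch (plain PyTorch, not part of the diff; the toy policy, shapes, and variable names are assumptions) of scoring the same actions under the current policy and a frozen reference copy, which yields the log-prob difference used as the KL estimate:

    import copy

    import torch
    from torch import nn
    from torch.distributions import Categorical

    # Toy policy head producing logits over 5 discrete actions.
    policy = nn.Linear(8, 5)

    # Frozen reference copy, standing in for the frozen_model argument above.
    ref_policy = copy.deepcopy(policy).requires_grad_(False)

    obs = torch.randn(3, 8)
    action = Categorical(logits=policy(obs)).sample()

    # Log-probabilities of the same actions under both models; their difference
    # is the single-sample KL estimate used in the reward correction.
    log_prob_current = Categorical(logits=policy(obs)).log_prob(action)
    with torch.no_grad():
        log_prob_ref = Categorical(logits=ref_policy(obs)).log_prob(action)
    kl_estimate = log_prob_current - log_prob_ref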