Commit 9b8056d

Update
[ghstack-poisoned]
1 parent 055b0c1 commit 9b8056d

File tree

1 file changed: +68 -4 lines changed

torchrl/envs/transforms/rlhf.py

+68 -4
@@ -2,23 +2,82 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from copy import copy, deepcopy
 
 import torch
 from tensordict import TensorDict, TensorDictBase, unravel_key
-from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
+from tensordict.nn import (
+    ProbabilisticTensorDictModule,
+    ProbabilisticTensorDictSequential,
+    TensorDictModuleBase,
+    TensorDictParams,
+    TensorDictSequential,
+)
 from tensordict.utils import is_seq_of_nested_key
 from torch import nn
 from torchrl.data.tensor_specs import Composite, Unbounded
 from torchrl.envs.transforms.transforms import Transform
 from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param
 
+
+# TODO: This should live somewhere else
+class ReferenceModelLogProbTransform(Transform):
+    """A transform to compute and store the log-probabilities from the reference model."""
+
+    def __init__(
+        self,
+        frozen_model: ProbabilisticTensorDictModule,
+    ):
+        super().__init__(in_keys=frozen_model.in_keys, out_keys=frozen_model.out_keys)
+        self.frozen_model: ProbabilisticTensorDictModule = frozen_model
+
+    def _call(self, inputs: TensorDict) -> TensorDict:
+        # Compute the log-prob given the reference model
+        return self.frozen_model(inputs)
+
+class KLDivergenceTransform(Transform):
+    """A transform to compute the KL divergence between the current and reference policies."""
+
+    ...
+
+
+class RewardAdjustmentTransform(Transform):
+    """A transform to adjust the reward based on the computed KL divergence."""
+
+    ...
+
+
+class KLConstrainedTransform(Composite):
+    """A composite transform to apply KL-based constraints on the policy."""
+
+    ...
+
 
 class KLRewardTransform(Transform):
-    """A transform to add a KL[pi_current||pi_0] correction term to the reward.
+    r"""A transform to add a KL divergence correction term to the reward.
 
     This transform is used to constrain the policy to remain close to its original
-    configuration which limits overfitting when fine-tuning using RLHF.
+    configuration, which helps limit overfitting when fine-tuning using Reinforcement Learning with Human Feedback
+    (RLHF) or other forms of post-training (e.g., GRPO).
+    The KL divergence between the current policy distribution and the reference policy distribution is used to adjust the reward:
+
+    .. math::
+
+        R_{\text{adjusted}} = R - \text{coef} \times \text{KL}(\pi_{\text{current}} || \pi_0)
+
+    where \( R_{\text{adjusted}} \) is the adjusted reward, \( R \) is the original reward, and
+    \(\text{KL}(\pi_{\text{current}} || \pi_0)\) is the Kullback-Leibler divergence between the current policy
+    distribution \( \pi_{\text{current}} \) and the reference policy distribution \( \pi_0 \).
+
+    The KL divergence can be estimated using the difference in log probabilities of the actions:
+
+    .. math::
+
+        \text{KL}(\pi_{\text{current}} || \pi_0) \approx \log p(a \mid \theta_{\text{current}}) - \log p(a \mid \theta_0)
+
+    where \( \log p(a \mid \theta_{\text{current}}) \) is the log probability of action \( a \) under the current model, and
+    \( \log p(a \mid \theta_0) \) is the log probability of action \( a \) under the reference model.
+
 
     Args:
         actor (ProbabilisticTensorDictModule): a probabilistic actor. It must
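
Note: as a rough illustration of the reward adjustment described in the docstring above, the sketch below computes the KL-corrected reward from per-action log-probabilities in plain PyTorch. The tensor values and the coefficient of 0.1 are made up for the example; the transform itself performs this bookkeeping on the tensordict's reward entry.

import torch

# Illustrative batch of three transitions (values are arbitrary).
coef = 0.1
reward = torch.tensor([1.0, 0.5, 2.0])           # original reward R
logp_current = torch.tensor([-1.2, -0.7, -2.3])  # log p(a | theta_current)
logp_ref = torch.tensor([-1.0, -0.9, -2.0])      # log p(a | theta_0), frozen reference model

# KL(pi_current || pi_0) estimated as the difference of log-probabilities
kl_estimate = logp_current - logp_ref

# R_adjusted = R - coef * KL(pi_current || pi_0)
adjusted_reward = reward - coef * kl_estimate
print(adjusted_reward)  # tensor([1.0200, 0.4800, 2.0300])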
@@ -86,6 +145,11 @@ def __init__(
         out_keys=None,
         requires_grad=False,
     ):
+        warnings.warn(
+            "This class will be removed in a future release (v0.10.0). Please use torchrl.envs.KLConstrainedTransform "
+            "instead.",
+            category=FutureWarning,
+        )
         if in_keys is None:
             in_keys = self.DEFAULT_IN_KEYS
         if out_keys is None:
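
Note on the warning category: FutureWarning is displayed to end users by default, whereas DeprecationWarning is hidden unless triggered from __main__, which is presumably why it is chosen for this user-facing deprecation. A minimal, self-contained sketch of the same pattern (the function name here is hypothetical, not part of the commit):

import warnings

def old_api():
    # Same pattern as the constructor above: warn users ahead of removal.
    warnings.warn(
        "old_api will be removed in a future release; use new_api instead.",
        category=FutureWarning,
    )

old_api()  # printed by default, unlike a DeprecationWarning raised outside __main__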
@@ -160,7 +224,7 @@ def _reset(
 
     def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
         # run the actor on the tensordict
-        action = next_tensordict.get("action", None)
+        action = next_tensordict.get("action")
         if action is None:
             # being called after reset or without action, skipping
             if self.out_keys[0] != ("reward",) and self.parent is not None:
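
For context on the new ReferenceModelLogProbTransform introduced in the first hunk: its _call simply runs the frozen model over the incoming tensordict. The sketch below shows, under assumed names and shapes (a toy linear policy with "observation"/"logits"/"action" keys, none of which come from this commit), how a frozen ProbabilisticTensorDictSequential can write the reference log-probabilities alongside the data.

import torch
from tensordict import TensorDict
from tensordict.nn import (
    InteractionType,
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
    TensorDictModule,
)
from torch import nn
from torch.distributions import Categorical

# Hypothetical frozen reference policy: observation -> logits -> Categorical action.
backbone = TensorDictModule(nn.Linear(4, 3), in_keys=["observation"], out_keys=["logits"])
prob_module = ProbabilisticTensorDictModule(
    in_keys=["logits"],
    out_keys=["action"],
    distribution_class=Categorical,
    default_interaction_type=InteractionType.RANDOM,
    return_log_prob=True,  # also writes the sample's log-probability into the tensordict
)
frozen_model = ProbabilisticTensorDictSequential(backbone, prob_module)
for p in frozen_model.parameters():
    p.requires_grad_(False)  # keep the reference model frozen

td = TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8])
td = frozen_model(td)  # essentially what ReferenceModelLogProbTransform._call does
print(td[prob_module.log_prob_key])  # log-probabilities under the reference model

In a full pipeline the frozen model would be wrapped in the transform and appended to the environment's transform stack (e.g. via TransformedEnv); the class as committed only forwards the call to the frozen model.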

0 commit comments
