@@ -143,8 +143,9 @@ def mutual_information_recursion(
      return_grad:
        Whether to return grads of ``px`` and ``py``, this grad standing for the
        occupation probability is the output of the backward with a
-       ``fake gradient`` input (all ones) This is useful to implement the
-       pruned version of rnnt loss.
+       ``fake gradient``; the ``fake gradient`` is the same as the gradient
+       you'd get if you did ``torch.autograd.grad((scores.sum()), [px, py])``.
+       This is useful to implement the pruned version of rnnt loss.

    Returns:
      Returns a torch.Tensor of shape ``[B]``, containing the log of the mutual
@@ -160,8 +161,8 @@ def mutual_information_recursion(
    where we handle edge cases by treating quantities with negative indexes
    as **-infinity**.  The extension to cases where the boundaries are
-   specified should be obvious; it just works on shorter sequences with offsets
-   into ``px`` and ``py``.
+   specified should be obvious; it just works on shorter sequences with
+   offsets into ``px`` and ``py``.
    """
    assert px.ndim == 3
    B, S, T1 = px.shape
@@ -179,10 +180,10 @@ def mutual_information_recursion(
    assert px.is_contiguous()
    assert py.is_contiguous()

-   m, px_grad, py_grad = MutualInformationRecursionFunction.apply(
+   scores, px_grad, py_grad = MutualInformationRecursionFunction.apply(
        px, py, boundary, return_grad
    )
-   return (m, (px_grad, py_grad)) if return_grad else m
+   return (scores, (px_grad, py_grad)) if return_grad else scores


def _inner_product(a: Tensor, b: Tensor) -> Tensor:
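
A minimal sketch (not part of the diff above) illustrating the docstring's claim about ``return_grad``: the returned ``(px_grad, py_grad)`` should match what autograd gives for ``scores.sum()`` with respect to ``px`` and ``py``. The import path and tensor shapes below are assumptions, not taken from this file.

# Sketch only; assumes the library is importable as `k2` and that px/py use
# the [B, S, T+1] / [B, S+1, T] shapes expected by the function.
import torch
from k2 import mutual_information_recursion  # assumed import path

B, S, T = 2, 5, 7
px = torch.randn(B, S, T + 1, requires_grad=True)  # [B, S, T+1]
py = torch.randn(B, S + 1, T, requires_grad=True)  # [B, S+1, T]

# Ask the function for the "fake gradient" outputs directly.
scores, (px_grad, py_grad) = mutual_information_recursion(
    px, py, boundary=None, return_grad=True
)

# Compare with the gradients autograd computes for scores.sum().
px_grad2, py_grad2 = torch.autograd.grad(scores.sum(), [px, py])
assert torch.allclose(px_grad, px_grad2, atol=1e-5)
assert torch.allclose(py_grad, py_grad2, atol=1e-5)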