Skip to content

Commit 47c4b75

Browse files
authored
Fix building doc (#912)
* Fix building doc * Fix flake8
1 parent 779a9bd commit 47c4b75

File tree

2 files changed

+14
-13
lines changed

2 files changed

+14
-13
lines changed

docs/requirements.txt

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
dataclasses
2-
graphviz
3-
recommonmark
4-
sphinx
5-
sphinx-autodoc-typehints
6-
sphinx_rtd_theme
7-
sphinxcontrib-bibtex
1+
dataclasses==0.6
2+
graphviz==0.19.1
3+
recommonmark==0.7.1
4+
sphinx==4.3.2
5+
sphinx-autodoc-typehints==1.12.0
6+
sphinx_rtd_theme==1.0.0
7+
sphinxcontrib-bibtex==2.4.1
88
torch>=1.6.0

k2/python/k2/mutual_information.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,9 @@ def mutual_information_recursion(
143143
return_grad:
144144
Whether to return grads of ``px`` and ``py``, this grad standing for the
145145
occupation probability is the output of the backward with a
146-
``fake gradient`` input (all ones) This is useful to implement the
147-
pruned version of rnnt loss.
146+
``fake gradient``; the ``fake gradient`` is the same as the gradient
147+
you'd get if you did ``torch.autograd.grad((scores.sum()), [px, py])``.
148+
This is useful to implement the pruned version of rnnt loss.
148149
149150
Returns:
150151
Returns a torch.Tensor of shape ``[B]``, containing the log of the mutual
@@ -160,8 +161,8 @@ def mutual_information_recursion(
160161
161162
where we handle edge cases by treating quantities with negative indexes
162163
as **-infinity**. The extension to cases where the boundaries are
163-
specified should be obvious; it just works on shorter sequences with offsets
164-
into ``px`` and ``py``.
164+
specified should be obvious; it just works on shorter sequences with
165+
offsets into ``px`` and ``py``.
165166
"""
166167
assert px.ndim == 3
167168
B, S, T1 = px.shape
@@ -179,10 +180,10 @@ def mutual_information_recursion(
179180
assert px.is_contiguous()
180181
assert py.is_contiguous()
181182

182-
m, px_grad, py_grad = MutualInformationRecursionFunction.apply(
183+
scores, px_grad, py_grad = MutualInformationRecursionFunction.apply(
183184
px, py, boundary, return_grad
184185
)
185-
return (m, (px_grad, py_grad)) if return_grad else m
186+
return (scores, (px_grad, py_grad)) if return_grad else scores
186187

187188

188189
def _inner_product(a: Tensor, b: Tensor) -> Tensor:

0 commit comments

Comments (0)