Skip to content

Commit 0feefc7

Browse files
authored
Modified rnnt (#902)
* Add modified mutual_information_recursion * Add modified rnnt loss * Using more efficient way to fix boundaries * Fix modified pruned rnnt loss * Fix the s_begin constraints of pruned loss for modified version transducer
1 parent 3cc74f1 commit 0feefc7

7 files changed

+682
-540
lines changed

k2/python/csrc/torch/mutual_information.h

+17-12
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,15 @@ namespace k2 {
3333
in mutual_information.py. This is the core recursion
3434
in the sequence-to-sequence mutual information computation.
3535
36-
@param px Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
37-
generating the next x in the sequence, i.e.
38-
xy[b][s][t] is the log of
39-
p(x_s | x_0..x_{s-1}, y_0..y_{s-1}) / p(x_s),
40-
i.e. the log-prob of generating x_s given subsequences of
41-
lengths (s, t), divided by the prior probability of generating
42-
x_s. (See mutual_information.py for more info).
36+
@param px Tensor of shape [B][S][T + 1] if not modified, [B][S][T] if
37+
modified. `modified` can be worked out from this. In the not-modified case,
38+
it can be thought of as the log-odds ratio of generating the next x in
39+
the sequence, i.e.
40+
xy[b][s][t] is the log of
41+
p(x_s | x_0..x_{s-1}, y_0..y_{s-1}) / p(x_s),
42+
i.e. the log-prob of generating x_s given subsequences of
43+
lengths (s, t), divided by the prior probability of generating x_s.
44+
(See mutual_information.py for more info).
4345
@param py The log-odds ratio of generating the next y in the sequence.
4446
Shape [B][S + 1][T]
4547
@param p This function writes to p[b][s][t] the mutual information between
@@ -49,10 +51,13 @@ namespace k2 {
4951
in the case where s_begin == t_begin == 0:
5052
5153
p[b,0,0] = 0.0
52-
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
54+
if not modified:
55+
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
5356
p[b,s,t-1] + py[b,s,t-1])
54-
if s > 0 or t > 0,
55-
treating values with any -1 index as -infinity.
57+
if modified:
58+
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
59+
p[b,s,t-1] + py[b,s,t-1])
60+
... treating values with any -1 index as -infinity.
5661
.. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
5762
@param boundary If set, a tensor of shape [B][4] of type int64_t, which
5863
contains, where for each batch element b, boundary[b]
@@ -79,8 +84,8 @@ torch::Tensor MutualInformationCpu(
7984
torch::Tensor p); // [B][S+1][T+1]; an output
8085

8186
torch::Tensor MutualInformationCuda(
82-
torch::Tensor px, // [B][S][T+1]
83-
torch::Tensor py, // [B][S+1][T]
87+
torch::Tensor px, // [B][S][T+1] if !modified, [B][S][T] if modified.
88+
torch::Tensor py, // [B][S+1][T]
8489
torch::optional<torch::Tensor> boundary, // [B][4], int64_t.
8590
torch::Tensor p); // [B][S+1][T+1]; an output
8691

k2/python/csrc/torch/mutual_information_cpu.cu

+69-27
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,30 @@
2323

2424
namespace k2 {
2525

26-
// forward of mutual_information. See also comment of `mutual_information`
26+
// forward of mutual_information. See """... """ comment of
27+
// `mutual_information_recursion`
2728
// in k2/python/k2/mutual_information.py for documentation of the
2829
// behavior of this function.
30+
31+
// px: of shape [B, S, T+1] if !modified, else [B, S, T] <-- work out
32+
// `modified` from this.
33+
// py: of shape [B, S+1, T]
34+
// boundary: of shape [B, 4], containing (s_begin, t_begin, s_end, t_end)
35+
// defaulting to (0, 0, S, T).
36+
// p: of shape (B, S+1, T+1)
37+
// Computes the recursion:
38+
// if !modified:
39+
// p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
40+
// p[b,s,t-1] + py[b,s,t-1])
41+
// if modified:
42+
// p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
43+
// p[b,s,t-1] + py[b,s,t-1])
44+
45+
// .. treating out-of-range elements as -infinity and with special cases:
46+
// p[b, s_begin, t_begin] = 0.0
47+
//
48+
// and this function returns a tensor of shape (B,) consisting of elements
49+
// p[b, s_end, t_end]
2950
torch::Tensor MutualInformationCpu(torch::Tensor px, torch::Tensor py,
3051
torch::optional<torch::Tensor> opt_boundary,
3152
torch::Tensor p) {
@@ -36,10 +57,13 @@ torch::Tensor MutualInformationCpu(torch::Tensor px, torch::Tensor py,
3657
px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu(),
3758
"inputs must be CPU tensors");
3859

60+
bool modified = (px.size(2) == py.size(2));
61+
3962
auto scalar_t = px.scalar_type();
4063
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
4164

42-
const int B = px.size(0), S = px.size(1), T = px.size(2) - 1;
65+
const int B = px.size(0), S = px.size(1), T = py.size(2);
66+
TORCH_CHECK(px.size(2) == (modified ? T : T + 1));
4367
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
4468
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
4569

@@ -61,28 +85,36 @@ torch::Tensor MutualInformationCpu(torch::Tensor px, torch::Tensor py,
6185
auto boundary_a = boundary.accessor<int64_t, 2>();
6286
auto ans_a = ans.accessor<scalar_t, 1>();
6387

88+
int t_offset = (modified ? -1 : 0);
6489
for (int b = 0; b < B; b++) {
6590
int s_begin = boundary_a[b][0];
6691
int t_begin = boundary_a[b][1];
6792
int s_end = boundary_a[b][2];
6893
int t_end = boundary_a[b][3];
6994
p_a[b][s_begin][t_begin] = 0.0;
70-
for (int s = s_begin + 1; s <= s_end; ++s)
71-
p_a[b][s][t_begin] =
72-
p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
95+
if (modified) {
96+
for (int s = s_begin + 1; s <= s_end; ++s)
97+
p_a[b][s][t_begin] = -std::numeric_limits<scalar_t>::infinity();
98+
} else {
99+
// note: t_offset = 0 so don't need t_begin + t_offset below.
100+
for (int s = s_begin + 1; s <= s_end; ++s)
101+
p_a[b][s][t_begin] =
102+
p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
103+
}
73104
for (int t = t_begin + 1; t <= t_end; ++t)
74105
p_a[b][s_begin][t] =
75106
p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
76107
for (int s = s_begin + 1; s <= s_end; ++s) {
77108
scalar_t p_s_t1 = p_a[b][s][t_begin];
78109
for (int t = t_begin + 1; t <= t_end; ++t) {
79110
// The following statement is a small optimization of:
80-
// p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
81-
// p_a[b][s][t - 1] + py_a[b][s][t - 1]);
111+
// p_a[b][s][t] = LogAdd(
112+
// p_a[b][s - 1][t + t_offset] + px_a[b][s - 1][t + t_offset],
113+
// p_a[b][s][t - 1] + py_a[b][s][t - 1]);
82114
// .. which obtains p_a[b][s][t - 1] from a register.
83-
p_a[b][s][t] = p_s_t1 =
84-
LogAdd<scalar_t>()(p_a[b][s - 1][t] + px_a[b][s - 1][t],
85-
p_s_t1 + py_a[b][s][t - 1]);
115+
p_a[b][s][t] = p_s_t1 = LogAdd<scalar_t>()(
116+
p_a[b][s - 1][t + t_offset] + px_a[b][s - 1][t + t_offset],
117+
p_s_t1 + py_a[b][s][t - 1]);
86118
}
87119
}
88120
ans_a[b] = p_a[b][s_end][t_end];
@@ -102,15 +134,18 @@ std::vector<torch::Tensor> MutualInformationBackwardCpu(
102134
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
103135
TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
104136

137+
bool modified = (px.size(2) == py.size(2));
138+
105139
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() &&
106140
p.device().is_cpu() && ans_grad.device().is_cpu(),
107141
"inputs must be CPU tensors");
108142

109143
auto scalar_t = px.scalar_type();
110144
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
111145

112-
const int B = px.size(0), S = px.size(1), T = px.size(2) - 1;
113-
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
146+
const int B = px.size(0), S = px.size(1), T = py.size(2);
147+
TORCH_CHECK(px.size(2) == (modified ? T : T + 1));
148+
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1);
114149
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
115150

116151
auto boundary = opt_boundary.value_or(
@@ -123,9 +158,10 @@ std::vector<torch::Tensor> MutualInformationBackwardCpu(
123158
TORCH_CHECK(boundary.device().is_cpu() && boundary.dtype() == torch::kInt64);
124159

125160
bool has_boundary = opt_boundary.has_value();
161+
int T1 = T + (modified ? 0 : 1);
126162
torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
127-
px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts)
128-
: torch::empty({B, S, T + 1}, opts)),
163+
px_grad = (has_boundary ? torch::zeros({B, S, T1}, opts)
164+
: torch::empty({B, S, T1}, opts)),
129165
py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts)
130166
: torch::empty({B, S + 1, T}, opts));
131167

@@ -138,6 +174,7 @@ std::vector<torch::Tensor> MutualInformationBackwardCpu(
138174

139175
auto ans_grad_a = ans_grad.accessor<scalar_t, 1>();
140176
auto boundary_a = boundary.accessor<int64_t, 2>();
177+
int t_offset = (modified ? -1 : 0);
141178

142179
for (int b = 0; b < B; b++) {
143180
int s_begin = boundary_a[b][0];
@@ -151,10 +188,12 @@ std::vector<torch::Tensor> MutualInformationBackwardCpu(
151188
for (int t = t_end; t > t_begin; --t) {
152189
// The s,t indexes correspond to
153190
// The statement we are backpropagating here is:
154-
// p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
155-
// p_a[b][s][t - 1] + py_a[b][s][t - 1]);
191+
// p_a[b][s][t] = LogAdd(
192+
// p_a[b][s - 1][t + t_offset] + px_a[b][s - 1][t + t_offset],
193+
// p_a[b][s][t - 1] + py_a[b][s][t - 1]);
156194
// .. which obtains p_a[b][s][t - 1] from a register.
157-
scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t],
195+
scalar_t term1 = p_a[b][s - 1][t + t_offset] +
196+
px_a[b][s - 1][t + t_offset],
158197
// term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1], <-- not
159198
// actually needed..
160199
total = p_a[b][s][t];
@@ -170,8 +209,8 @@ std::vector<torch::Tensor> MutualInformationBackwardCpu(
170209
// could happen if total == -inf
171210
term1_grad = term2_grad = 0.0;
172211
}
173-
px_grad_a[b][s - 1][t] = term1_grad;
174-
p_grad_a[b][s - 1][t] = term1_grad;
212+
px_grad_a[b][s - 1][t + t_offset] = term1_grad;
213+
p_grad_a[b][s - 1][t + t_offset] = term1_grad;
175214
py_grad_a[b][s][t - 1] = term2_grad;
176215
p_grad_a[b][s][t - 1] += term2_grad;
177216
}
@@ -184,14 +223,17 @@ std::vector<torch::Tensor> MutualInformationBackwardCpu(
184223
p_grad_a[b][s_begin][t - 1] += this_p_grad;
185224
py_grad_a[b][s_begin][t - 1] = this_p_grad;
186225
}
187-
for (int s = s_end; s > s_begin; --s) {
188-
// Backprop for:
189-
// p_a[b][s][t_begin] =
190-
// p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
191-
scalar_t this_p_grad = p_grad_a[b][s][t_begin];
192-
p_grad_a[b][s - 1][t_begin] += this_p_grad;
193-
px_grad_a[b][s - 1][t_begin] = this_p_grad;
194-
}
226+
if (!modified) {
227+
for (int s = s_end; s > s_begin; --s) {
228+
// Backprop for:
229+
// p_a[b][s][t_begin] =
230+
// p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
231+
scalar_t this_p_grad = p_grad_a[b][s][t_begin];
232+
p_grad_a[b][s - 1][t_begin] += this_p_grad;
233+
px_grad_a[b][s - 1][t_begin] = this_p_grad;
234+
}
235+
} // else these were all -infinity's and there is nothing to
236+
// backprop.
195237
// There is no backprop for:
196238
// p_a[b][s_begin][t_begin] = 0.0;
197239
// .. but we can use this for a check, that the grad at the beginning

0 commit comments

Comments
 (0)