23
23
24
24
namespace k2 {
25
25
26
- // forward of mutual_information. See also comment of `mutual_information`
26
+ // forward of mutual_information. See """... """ comment of
27
+ // `mutual_information_recursion` in
27
28
// in k2/python/k2/mutual_information.py for documentation of the
28
29
// behavior of this function.
30
+
31
+ // px: of shape [B, S, T+1] if !modified, else [B, S, T] <-- work out
32
+ // `modified` from this.
33
+ // py: of shape [B, S+1, T]
34
+ // boundary: of shape [B, 4], containing (s_begin, t_begin, s_end, t_end)
35
+ // defaulting to (0, 0, S, T).
36
+ // p: of shape (S+1, T+1)
37
+ // Computes the recursion:
38
+ // if !modified:
39
+ // p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
40
+ // p[b,s,t-1] + py[b,s,t-1])
41
+ // if modified:
42
+ // p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
43
+ // p[b,s,t-1] + py[b,s,t-1])
44
+
45
+ // .. treating out-of-range elements as -infinity and with special cases:
46
+ // p[b, s_begin, t_begin] = 0.0
47
+ //
48
+ // and this function returns a tensor of shape (B,) consisting of elements
49
+ // p[b, s_end, t_end]
29
50
torch::Tensor MutualInformationCpu (torch::Tensor px, torch::Tensor py,
30
51
torch::optional<torch::Tensor> opt_boundary,
31
52
torch::Tensor p) {
@@ -36,10 +57,13 @@ torch::Tensor MutualInformationCpu(torch::Tensor px, torch::Tensor py,
36
57
px.device ().is_cpu () && py.device ().is_cpu () && p.device ().is_cpu (),
37
58
" inputs must be CPU tensors" );
38
59
60
+ bool modified = (px.size (2 ) == py.size (2 ));
61
+
39
62
auto scalar_t = px.scalar_type ();
40
63
auto opts = torch::TensorOptions ().dtype (scalar_t ).device (px.device ());
41
64
42
- const int B = px.size (0 ), S = px.size (1 ), T = px.size (2 ) - 1 ;
65
+ const int B = px.size (0 ), S = px.size (1 ), T = py.size (2 );
66
+ TORCH_CHECK (px.size (2 ) == (modified ? T : T + 1 ));
43
67
TORCH_CHECK (py.size (0 ) == B && py.size (1 ) == S + 1 && py.size (2 ) == T);
44
68
TORCH_CHECK (p.size (0 ) == B && p.size (1 ) == S + 1 && p.size (2 ) == T + 1 );
45
69
@@ -61,28 +85,36 @@ torch::Tensor MutualInformationCpu(torch::Tensor px, torch::Tensor py,
61
85
auto boundary_a = boundary.accessor <int64_t , 2 >();
62
86
auto ans_a = ans.accessor <scalar_t , 1 >();
63
87
88
+ int t_offset = (modified ? -1 : 0 );
64
89
for (int b = 0 ; b < B; b++) {
65
90
int s_begin = boundary_a[b][0 ];
66
91
int t_begin = boundary_a[b][1 ];
67
92
int s_end = boundary_a[b][2 ];
68
93
int t_end = boundary_a[b][3 ];
69
94
p_a[b][s_begin][t_begin] = 0.0 ;
70
- for (int s = s_begin + 1 ; s <= s_end; ++s)
71
- p_a[b][s][t_begin] =
72
- p_a[b][s - 1 ][t_begin] + px_a[b][s - 1 ][t_begin];
95
+ if (modified) {
96
+ for (int s = s_begin + 1 ; s <= s_end; ++s)
97
+ p_a[b][s][t_begin] = -std::numeric_limits<scalar_t >::infinity ();
98
+ } else {
99
+ // note: t_offset = 0 so don't need t_begin + t_offset below.
100
+ for (int s = s_begin + 1 ; s <= s_end; ++s)
101
+ p_a[b][s][t_begin] =
102
+ p_a[b][s - 1 ][t_begin] + px_a[b][s - 1 ][t_begin];
103
+ }
73
104
for (int t = t_begin + 1 ; t <= t_end; ++t)
74
105
p_a[b][s_begin][t] =
75
106
p_a[b][s_begin][t - 1 ] + py_a[b][s_begin][t - 1 ];
76
107
for (int s = s_begin + 1 ; s <= s_end; ++s) {
77
108
scalar_t p_s_t1 = p_a[b][s][t_begin];
78
109
for (int t = t_begin + 1 ; t <= t_end; ++t) {
79
110
// The following statement is a small optimization of:
80
- // p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
81
- // p_a[b][s][t - 1] + py_a[b][s][t - 1]);
111
+ // p_a[b][s][t] = LogAdd(
112
+ // p_a[b][s - 1][t + t_offset] + px_a[b][s -1][t + t_offset],
113
+ // p_a[b][s][t - 1] + py_a[b][s][t - 1]);
82
114
// .. which obtains p_a[b][s][t - 1] from a register.
83
- p_a[b][s][t] = p_s_t1 =
84
- LogAdd< scalar_t >()( p_a[b][s - 1 ][t] + px_a[b][s - 1 ][t],
85
- p_s_t1 + py_a[b][s][t - 1 ]);
115
+ p_a[b][s][t] = p_s_t1 = LogAdd< scalar_t >()(
116
+ p_a[b][s - 1 ][t + t_offset ] + px_a[b][s - 1 ][t + t_offset ],
117
+ p_s_t1 + py_a[b][s][t - 1 ]);
86
118
}
87
119
}
88
120
ans_a[b] = p_a[b][s_end][t_end];
@@ -102,15 +134,18 @@ std::vector<torch::Tensor> MutualInformationBackwardCpu(
102
134
TORCH_CHECK (p.dim () == 3 , " p must be 3-dimensional." );
103
135
TORCH_CHECK (ans_grad.dim () == 1 , " ans_grad must be 1-dimensional." );
104
136
137
+ bool modified = (px.size (2 ) == py.size (2 ));
138
+
105
139
TORCH_CHECK (px.device ().is_cpu () && py.device ().is_cpu () &&
106
140
p.device ().is_cpu () && ans_grad.device ().is_cpu (),
107
141
" inputs must be CPU tensors" );
108
142
109
143
auto scalar_t = px.scalar_type ();
110
144
auto opts = torch::TensorOptions ().dtype (scalar_t ).device (px.device ());
111
145
112
- const int B = px.size (0 ), S = px.size (1 ), T = px.size (2 ) - 1 ;
113
- TORCH_CHECK (py.size (0 ) == B && py.size (1 ) == S + 1 && py.size (2 ) == T);
146
+ const int B = px.size (0 ), S = px.size (1 ), T = py.size (2 );
147
+ TORCH_CHECK (px.size (2 ) == (modified ? T : T + 1 ));
148
+ TORCH_CHECK (py.size (0 ) == B && py.size (1 ) == S + 1 );
114
149
TORCH_CHECK (p.size (0 ) == B && p.size (1 ) == S + 1 && p.size (2 ) == T + 1 );
115
150
116
151
auto boundary = opt_boundary.value_or (
@@ -123,9 +158,10 @@ std::vector<torch::Tensor> MutualInformationBackwardCpu(
123
158
TORCH_CHECK (boundary.device ().is_cpu () && boundary.dtype () == torch::kInt64 );
124
159
125
160
bool has_boundary = opt_boundary.has_value ();
161
+ int T1 = T + (modified ? 0 : 1 );
126
162
torch::Tensor p_grad = torch::zeros ({B, S + 1 , T + 1 }, opts),
127
- px_grad = (has_boundary ? torch::zeros ({B, S, T + 1 }, opts)
128
- : torch::empty ({B, S, T + 1 }, opts)),
163
+ px_grad = (has_boundary ? torch::zeros ({B, S, T1 }, opts)
164
+ : torch::empty ({B, S, T1 }, opts)),
129
165
py_grad = (has_boundary ? torch::zeros ({B, S + 1 , T}, opts)
130
166
: torch::empty ({B, S + 1 , T}, opts));
131
167
@@ -138,6 +174,7 @@ std::vector<torch::Tensor> MutualInformationBackwardCpu(
138
174
139
175
auto ans_grad_a = ans_grad.accessor <scalar_t , 1 >();
140
176
auto boundary_a = boundary.accessor <int64_t , 2 >();
177
+ int t_offset = (modified ? -1 : 0 );
141
178
142
179
for (int b = 0 ; b < B; b++) {
143
180
int s_begin = boundary_a[b][0 ];
@@ -151,10 +188,12 @@ std::vector<torch::Tensor> MutualInformationBackwardCpu(
151
188
for (int t = t_end; t > t_begin; --t) {
152
189
// The s,t indexes correspond to
153
190
// The statement we are backpropagating here is:
154
- // p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
155
- // p_a[b][s][t - 1] + py_a[b][s][t - 1]);
191
+ // p_a[b][s][t] = LogAdd(
192
+ // p_a[b][s - 1][t + t_offset] + px_a[b][s - 1][t + t_offset],
193
+ // p_a[b][s][t - 1] + py_a[b][s][t - 1]);
156
194
// .. which obtains p_a[b][s][t - 1] from a register.
157
- scalar_t term1 = p_a[b][s - 1 ][t] + px_a[b][s - 1 ][t],
195
+ scalar_t term1 = p_a[b][s - 1 ][t + t_offset] +
196
+ px_a[b][s - 1 ][t + t_offset],
158
197
// term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1], <-- not
159
198
// actually needed..
160
199
total = p_a[b][s][t];
@@ -170,8 +209,8 @@ std::vector<torch::Tensor> MutualInformationBackwardCpu(
170
209
// could happen if total == -inf
171
210
term1_grad = term2_grad = 0.0 ;
172
211
}
173
- px_grad_a[b][s - 1 ][t] = term1_grad;
174
- p_grad_a[b][s - 1 ][t] = term1_grad;
212
+ px_grad_a[b][s - 1 ][t + t_offset ] = term1_grad;
213
+ p_grad_a[b][s - 1 ][t + t_offset ] = term1_grad;
175
214
py_grad_a[b][s][t - 1 ] = term2_grad;
176
215
p_grad_a[b][s][t - 1 ] += term2_grad;
177
216
}
@@ -184,14 +223,17 @@ std::vector<torch::Tensor> MutualInformationBackwardCpu(
184
223
p_grad_a[b][s_begin][t - 1 ] += this_p_grad;
185
224
py_grad_a[b][s_begin][t - 1 ] = this_p_grad;
186
225
}
187
- for (int s = s_end; s > s_begin; --s) {
188
- // Backprop for:
189
- // p_a[b][s][t_begin] =
190
- // p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
191
- scalar_t this_p_grad = p_grad_a[b][s][t_begin];
192
- p_grad_a[b][s - 1 ][t_begin] += this_p_grad;
193
- px_grad_a[b][s - 1 ][t_begin] = this_p_grad;
194
- }
226
+ if (!modified) {
227
+ for (int s = s_end; s > s_begin; --s) {
228
+ // Backprop for:
229
+ // p_a[b][s][t_begin] =
230
+ // p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
231
+ scalar_t this_p_grad = p_grad_a[b][s][t_begin];
232
+ p_grad_a[b][s - 1 ][t_begin] += this_p_grad;
233
+ px_grad_a[b][s - 1 ][t_begin] = this_p_grad;
234
+ }
235
+ } // else these were all -infinity's and there is nothing to
236
+ // backprop.
195
237
// There is no backprop for:
196
238
// p_a[b][s_begin][t_begin] = 0.0;
197
239
// .. but we can use this for a check, that the grad at the beginning
0 commit comments