Skip to content

Commit 9832893

Browse files
authored
Resolve comments. (#871)
* Resolve comments. * Minor fixes.
1 parent 83faafa commit 9832893

File tree

5 files changed

+28
-49
lines changed

5 files changed

+28
-49
lines changed

k2/torch/bin/attention_rescore.cu

+1-42
Original file line numberDiff line numberDiff line change
@@ -281,55 +281,14 @@ int main(int argc, char *argv[]) {
281281
std::vector<std::vector<int32_t>> token_ids = tokens.ToVecVec();
282282
// convert std::vector<std::vector<int32_t>>
283283
// to
284-
// torch::List<torch::IValue> where torch::IValue is torch::List<int32_t>
285-
#if 0
286-
// clang-format off
287-
//
288-
// This branch is used when `token_ids` is of type List[List[int]],
289-
// but it throws the following error during runtime:
290-
//
291-
/*
292-
terminate called after throwing an instance of 'std::runtime_error'
293-
what(): The following operation failed in the TorchScript interpreter.
294-
Traceback of TorchScript, serialized code (most recent call last):
295-
File "code/__torch__/conformer.py", line 97, in decoder_nll
296-
for _22 in range(torch.len(ys_out)):
297-
y1 = ys_out[_22]
298-
_23 = torch.tensor(y1, dtype=None, device=None, requires_grad=False)
299-
~~~~~~~~~~~~ <--- HERE
300-
_24 = torch.append(ys_out1, _23)
301-
ys_out_pad = _15(ys_out1, True, -1., )
302-
303-
Traceback of TorchScript, original code (most recent call last):
304-
File "/ceph-fj/open-source/icefall-torchscript/egs/librispeech/ASR/conformer_ctc/transformer.py", line 340, in decoder_nll
305-
306-
ys_out = add_eos(token_ids, eos_id=eos_id)
307-
ys_out = [torch.tensor(y) for y in ys_out]
308-
~~~~~~~~~~~~ <--- HERE
309-
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
310-
RuntimeError: Add new condition, expected Float, Complex, Int, or Bool but got int
311-
*/
312-
torch::List<torch::IValue> token_ids_list(torch::ListType::ofInts());
313-
314-
token_ids_list.reserve(token_ids.size());
315-
for (const auto tids : token_ids) {
316-
torch::List<torch::IValue> tmp(torch::IntType::create());
317-
tmp.reserve(tids.size());
318-
for (auto i : tids) {
319-
tmp.emplace_back(torch::IValue(i));
320-
}
321-
token_ids_list.push_back(std::move(tmp));
322-
}
323-
// clang-format on
324-
#else
284+
// torch::List<torch::IValue> where torch::IValue is torch::Tensor
325285
torch::List<torch::IValue> token_ids_list(torch::TensorType::get());
326286

327287
token_ids_list.reserve(token_ids.size());
328288
for (const auto tids : token_ids) {
329289
torch::Tensor tids_tensor = torch::tensor(tids);
330290
token_ids_list.emplace_back(tids_tensor);
331291
}
332-
#endif
333292

334293
K2_LOG(INFO) << "Run attention decoder";
335294
torch::Tensor nll =

k2/torch/csrc/fsa_algo.cu

+15
Original file line numberDiff line numberDiff line change
@@ -237,4 +237,19 @@ Nbest RandomPaths(FsaClass &lattice, int32_t num_paths) {
237237
return {ans_lattice, utt_to_path_shape};
238238
}
239239

240+
FsaClass IntersectDevice(FsaClass &a_fsas, FsaClass &b_fsas,
241+
const Array1<int32_t> &b_to_a_map,
242+
bool sorted_match_a) {
243+
Array1<int32_t> arc_map_a, arc_map_b;
244+
245+
Fsa c_fsas = IntersectDevice(a_fsas.fsa, a_fsas.Properties(), b_fsas.fsa,
246+
b_fsas.Properties(), b_to_a_map, &arc_map_a,
247+
&arc_map_b, sorted_match_a);
248+
249+
FsaClass ans(c_fsas);
250+
ans.CopyAttrs(a_fsas, Array1ToTorch(arc_map_a));
251+
ans.CopyAttrs(b_fsas, Array1ToTorch(arc_map_b));
252+
return ans;
253+
}
254+
240255
} // namespace k2

k2/torch/csrc/fsa_algo.h

+6
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,12 @@ void TopSort(FsaClass *lattice);
126126
*/
127127
Nbest RandomPaths(FsaClass &lattice, int32_t num_paths);
128128

129+
/// Wrapper for k2::IntersectDevice() in k2/csrc/fsa_algo.h
130+
/// to support attribute propagation.
131+
FsaClass IntersectDevice(FsaClass &a_fsas, FsaClass &b_fsas,
132+
const Array1<int32_t> &b_to_a_map,
133+
bool sorted_match_a);
134+
129135
} // namespace k2
130136

131137
#endif // K2_TORCH_CSRC_FSA_ALGO_H_

k2/torch/csrc/nbest.cu

+5-7
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,13 @@ void Nbest::Intersect(FsaClass *lattice) {
6666
// Now lattice has word IDs as labels and token IDs as aux_labels
6767
ArcSort(lattice);
6868

69-
Array1<int32_t> arc_map_a, arc_map_b;
69+
FsaClass word_fsa_with_epsilon_self_loops_wrapper(
70+
word_fsa_with_epsilon_self_loops);
7071

71-
Fsa path_lattice = k2::IntersectDevice(
72-
lattice->fsa, lattice->Properties(), word_fsa_with_epsilon_self_loops,
73-
FsaClass(word_fsa_with_epsilon_self_loops).Properties(), path_to_utt_map,
74-
&arc_map_a, &arc_map_b, true);
72+
FsaClass ans =
73+
IntersectDevice(*lattice, word_fsa_with_epsilon_self_loops_wrapper,
74+
path_to_utt_map, true);
7575

76-
FsaClass ans(path_lattice);
77-
ans.CopyAttrs(*lattice, k2::Array1ToTorch(arc_map_a));
7876
Connect(&ans);
7977
TopSort(&ans);
8078
ans = ShortestPath(ans);

k2/torch/csrc/nbest.h

+1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ struct Nbest {
7878
/// `this` nbest.
7979
///
8080
/// @param lattice The lattice to intersect. Note it is modified in-place.
81+
/// You should not use it after invoking this function.
8182
///
8283
/// Note: The scores for the return value of FromLattice() are
8384
/// all 0s.

0 commit comments

Comments
 (0)