@@ -281,55 +281,14 @@ int main(int argc, char *argv[]) {
281
281
std::vector<std::vector<int32_t >> token_ids = tokens.ToVecVec ();
282
282
// convert std::vector<std::vector<int32_t>>
283
283
// to
284
- // torch::List<torch::IValue> where torch::IValue is torch::List<int32_t>
285
- #if 0
286
- // clang-format off
287
- //
288
- // This branch is used when `token_ids` is of type List[List[int]],
289
- // but it throws the following error during runtime:
290
- //
291
- /*
292
- terminate called after throwing an instance of 'std::runtime_error'
293
- what(): The following operation failed in the TorchScript interpreter.
294
- Traceback of TorchScript, serialized code (most recent call last):
295
- File "code/__torch__/conformer.py", line 97, in decoder_nll
296
- for _22 in range(torch.len(ys_out)):
297
- y1 = ys_out[_22]
298
- _23 = torch.tensor(y1, dtype=None, device=None, requires_grad=False)
299
- ~~~~~~~~~~~~ <--- HERE
300
- _24 = torch.append(ys_out1, _23)
301
- ys_out_pad = _15(ys_out1, True, -1., )
302
-
303
- Traceback of TorchScript, original code (most recent call last):
304
- File "/ceph-fj/open-source/icefall-torchscript/egs/librispeech/ASR/conformer_ctc/transformer.py", line 340, in decoder_nll
305
-
306
- ys_out = add_eos(token_ids, eos_id=eos_id)
307
- ys_out = [torch.tensor(y) for y in ys_out]
308
- ~~~~~~~~~~~~ <--- HERE
309
- ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
310
- RuntimeError: Add new condition, expected Float, Complex, Int, or Bool but gotint
311
- */
312
- torch::List<torch::IValue> token_ids_list(torch::ListType::ofInts());
313
-
314
- token_ids_list.reserve(token_ids.size());
315
- for (const auto tids : token_ids) {
316
- torch::List<torch::IValue> tmp(torch::IntType::create());
317
- tmp.reserve(tids.size());
318
- for (auto i : tids) {
319
- tmp.emplace_back(torch::IValue(i));
320
- }
321
- token_ids_list.push_back(std::move(tmp));
322
- }
323
- // clang-format on
324
- #else
284
+ // torch::List<torch::IValue> where torch::IValue is torch::Tensor
325
285
torch::List<torch::IValue> token_ids_list (torch::TensorType::get ());
326
286
327
287
token_ids_list.reserve (token_ids.size ());
328
288
for (const auto tids : token_ids) {
329
289
torch::Tensor tids_tensor = torch::tensor (tids);
330
290
token_ids_list.emplace_back (tids_tensor);
331
291
}
332
- #endif
333
292
334
293
K2_LOG (INFO) << " Run attention decoder" ;
335
294
torch::Tensor nll =
0 commit comments