From a2a79c7992f8ef45e1843d419236c5b0b3eb5ac8 Mon Sep 17 00:00:00 2001
From: suxia
Date: Wed, 10 Aug 2022 17:59:31 +0800
Subject: [PATCH 1/4] Add function display_and_save_batch in
 egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py

---
 .../ASR/pruned_transducer_stateless2/train.py | 73 ++++++++++++++-----
 1 file changed, 54 insertions(+), 19 deletions(-)

diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index faf25eda1d..d6c4de58f0 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -701,25 +701,29 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
-            loss, loss_info = compute_loss(
-                params=params,
-                model=model,
-                graph_compiler=graph_compiler,
-                batch=batch,
-                is_training=True,
-                warmup=(params.batch_idx_train / params.model_warm_step),
-            )
-        # summary stats
-        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-
-        # NOTE: We use reduction==sum and loss is computed over utterances
-        # in the batch and there is no normalization to it so far.
-        scaler.scale(loss).backward()
-        scheduler.step_batch(params.batch_idx_train)
-        scaler.step(optimizer)
-        scaler.update()
-        optimizer.zero_grad()
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                    warmup=(params.batch_idx_train / params.model_warm_step),
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            scheduler.step_batch(params.batch_idx_train)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except: # noqa
+            display_and_save_batch(batch, params=params)
+            raise
 
         if params.print_diagnostics and batch_idx == 5:
             return
@@ -957,6 +961,34 @@ def remove_short_and_long_utt(c: Cut):
         torch.distributed.barrier()
         cleanup_dist()
 
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    num_tokens = params.vocab_size
+    logging.info(f"num tokens: {num_tokens}")
 
 def scan_pessimistic_batches_for_oom(
     model: nn.Module,
@@ -998,6 +1030,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
+                display_and_save_batch(batch, params=params)
                 raise
 
 
@@ -1021,3 +1054,5 @@ def main():
 
 if __name__ == "__main__":
     main()
+
+

From 8d2078dba46c5317ad21c0a7d3598086a570d819 Mon Sep 17 00:00:00 2001
From: suxia
Date: Thu, 11 Aug 2022 10:47:31 +0800
Subject: [PATCH 2/4] Modify function: display_and_save_batch

---
 egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index d6c4de58f0..8eafb7df7d 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -961,6 +961,7 @@ def remove_short_and_long_utt(c: Cut):
         torch.distributed.barrier()
         cleanup_dist()
 
+
 def display_and_save_batch(
     batch: dict,
     params: AttributeDict,
@@ -973,8 +974,6 @@ def display_and_save_batch(
         for the content in it.
       params:
         Parameters for training. See :func:`get_params`.
-      sp:
-        The BPE model.
     """
     from lhotse.utils import uuid4
 
@@ -982,14 +981,16 @@ def display_and_save_batch(
     logging.info(f"Saving batch to {filename}")
    torch.save(batch, filename)
 
-    supervisions = batch["supervisions"]
     features = batch["inputs"]
 
     logging.info(f"features shape: {features.shape}")
 
-    num_tokens = params.vocab_size
+    texts = batch["supervisions"]["text"]
+    num_tokens = sum(len(i) for i in texts)
+
     logging.info(f"num tokens: {num_tokens}")
 
+
 def scan_pessimistic_batches_for_oom(
     model: nn.Module,
     train_dl: torch.utils.data.DataLoader,

From d21077c03df408d1c8538bd889b378e73c3224be Mon Sep 17 00:00:00 2001
From: suxia
Date: Thu, 11 Aug 2022 11:00:39 +0800
Subject: [PATCH 3/4] Delete empty line in
 pruned_transducer_stateless2/train.py

---
 egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index 8eafb7df7d..5ac888dccd 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -1055,5 +1055,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
-

From ecf165ede272421198dd2fa053a1fd0dab2586cb Mon Sep 17 00:00:00 2001
From: suxia
Date: Fri, 12 Aug 2022 10:42:26 +0800
Subject: [PATCH 4/4] Modify code format

---
 egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index 5ac888dccd..5208dbefe4 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -721,7 +721,7 @@ def train_one_epoch(
             scaler.step(optimizer)
             scaler.update()
             optimizer.zero_grad()
-        except: # noqa
+        except:  # noqa
             display_and_save_batch(batch, params=params)
             raise
 
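
Note (outside the patch series): a batch file written by display_and_save_batch can
be inspected offline, e.g. when chasing the OOM or inf/nan loss that triggered the
dump. The sketch below is illustrative only; it assumes a dump already exists under
an "exp" directory (substitute your actual params.exp_dir):

    from pathlib import Path

    import torch

    # Pick the most recent batch dump written by display_and_save_batch.
    dump = max(Path("exp").glob("batch-*.pt"), key=lambda p: p.stat().st_mtime)
    # Map to CPU so the batch can be examined on a machine without a GPU.
    batch = torch.load(dump, map_location="cpu")

    features = batch["inputs"]             # padded acoustic features, (N, T, C)
    texts = batch["supervisions"]["text"]  # one transcript per utterance

    print(f"features shape: {features.shape}")
    print(f"num utterances: {len(texts)}")
    # Same statistic the final version of the function logs: total characters
    # over all transcripts (one character per token in this char-based
    # Chinese recipe).
    print(f"num tokens: {sum(len(t) for t in texts)}")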