Add Zipformer from Dan #672

Merged: 669 commits into master from from-dan-scaled-adam-exp253 on Nov 12, 2022
Conversation

@csukuangfj (Collaborator) commented Nov 11, 2022

All the changes in this PR are from @danpovey.

Takeaways:
(1) The model is trained using only LibriSpeech, and it has about 70.37 M parameters.
(2) We can use a much larger --max-duration, i.e., 750 (see the sketch below).
(3) We use half-precision (fp16) during training.
(4) The model converges much faster and yields the best WER we have obtained on LibriSpeech without extra data from GigaSpeech.
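
For context, --max-duration caps the total audio duration per batch, in seconds, rather than the number of utterances, so the batch size adapts to utterance length. A minimal sketch of the idea (not the actual sampler used in icefall, which is more involved):

from typing import Iterable, Iterator, List, Tuple

def duration_batches(
    utts: Iterable[Tuple[str, float]], max_duration: float = 750.0
) -> Iterator[List[str]]:
    """Group utterances (id, seconds) into batches whose total audio
    duration stays below max_duration seconds."""
    batch: List[str] = []
    total = 0.0
    for utt_id, secs in utts:
        if batch and total + secs > max_duration:
            yield batch
            batch, total = [], 0.0
        batch.append(utt_id)
        total += secs
    if batch:
        yield batch

A larger --max-duration means larger effective batches; half-precision is what makes 750 affordable on 32 GB GPUs.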

Here are the results:

decoding method       test-clean  test-other  comment
greedy_search         2.17        5.23        --epoch 30 --avg 9
modified_beam_search  2.15        5.20        --epoch 30 --avg 9
fast_beam_search      2.15        5.22        --epoch 30 --avg 9

(Hint: You can find the results of the Conformer paper at https://arxiv.org/pdf/2005.08100.pdf)

Training command:

export CUDA_VISIBLE_DEVICES="0,3,6,7"

./pruned_transducer_stateless7/train.py \
  --world-size 4 \
  --num-epochs 30 \
  --full-libri 1 \
  --use-fp16 1 \
  --max-duration 750 \
  --exp-dir pruned_transducer_stateless7/exp \
  --feedforward-dims  "1024,1024,2048,2048,1024" \
  --master-port 12535

Things to note:

  1. Training uses 4 V100 GPUs (32 GB memory each).
  2. Training uses float16, i.e., half-precision (see the sketch below).
  3. --max-duration is 750, i.e., up to 750 seconds of audio per batch.
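
For reference, half-precision training in PyTorch is typically implemented with automatic mixed precision plus a gradient scaler; train.py enables it via --use-fp16. A generic sketch of that pattern, where model, optimizer, and train_loader are placeholder names rather than the PR's actual objects:

import torch

scaler = torch.cuda.amp.GradScaler()

for batch in train_loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(dtype=torch.float16):
        loss = model(batch)           # forward pass runs in fp16
    scaler.scale(loss).backward()     # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)            # unscale gradients, then step
    scaler.update()                   # adjust the scale factor for the next step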

To give you an idea of the training time per epoch:

(py38) kuangfangjun:exp$ ls -lhrt epoch-* | tail
-rw-r--r-- 1 kuangfangjun root 1.1G Nov  8 01:41 epoch-21.pt
-rw-r--r-- 1 kuangfangjun root 1.1G Nov  8 03:04 epoch-22.pt
-rw-r--r-- 1 kuangfangjun root 1.1G Nov  8 04:26 epoch-23.pt
-rw-r--r-- 1 kuangfangjun root 1.1G Nov  8 05:48 epoch-24.pt
-rw-r--r-- 1 kuangfangjun root 1.1G Nov  8 07:10 epoch-25.pt
-rw-r--r-- 1 kuangfangjun root 1.1G Nov  8 08:32 epoch-26.pt
-rw-r--r-- 1 kuangfangjun root 1.1G Nov  8 09:55 epoch-27.pt
-rw-r--r-- 1 kuangfangjun root 1.1G Nov  8 11:17 epoch-28.pt
-rw-r--r-- 1 kuangfangjun root 1.1G Nov  8 12:40 epoch-29.pt
-rw-r--r-- 1 kuangfangjun root 1.1G Nov  8 14:02 epoch-30.pt

It is about 1 hour and 20 minutes per epoch.

The number of model parameters:

2022-11-06 21:04:40,836 INFO [train.py:984] (0/4) Number of model parameters: 70369391

That is, there are about 70.37 M parameters.
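
For reference, that figure is the standard PyTorch parameter count:

# `model` is the constructed transducer (placeholder name for illustration).
num_param = sum(p.numel() for p in model.parameters())
print(f"Number of model parameters: {num_param}")  # prints 70369391 here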


Decoding commands

(Note: I list only --epoch 30 --avg 9, but I have searched almost all combinations of --epoch and --avg, and --epoch 30 --avg 9 is the best one.)

for m in greedy_search fast_beam_search modified_beam_search ; do
  for epoch in 30; do
    for avg in 9; do
      ./pruned_transducer_stateless7/decode.py \
          --epoch $epoch \
          --avg $avg \
          --use-averaged-model 1 \
          --exp-dir ./pruned_transducer_stateless7/exp \
          --feedforward-dims  "1024,1024,2048,2048,1024" \
          --max-duration 600 \
          --decoding-method $m
    done
  done
done
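
For context, --avg 9 with --epoch 30 averages model weights over the last 9 checkpoints, and --use-averaged-model 1 selects a refinement of plain averaging. A minimal sketch of the plain-averaging idea, not icefall's exact implementation:

import torch

def average_checkpoints(filenames):
    """Average model weights across checkpoints. Assumes each file stores
    the state dict under the key "model", as icefall checkpoints do."""
    n = len(filenames)
    avg = torch.load(filenames[0], map_location="cpu")["model"]
    for name in filenames[1:]:
        state = torch.load(name, map_location="cpu")["model"]
        for k in avg:
            avg[k] += state[k]
    for k in avg:
        if avg[k].is_floating_point():
            avg[k] /= n
        else:
            avg[k] //= n  # integer buffers, e.g. num_batches_tracked
    return avg

# e.g. the 9 checkpoints ending at epoch 30:
# avg_state = average_checkpoints([f"exp/epoch-{e}.pt" for e in range(22, 31)])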

To give you an idea of how the model performs in the early epochs:

(1) Validation loss

(py38) kuangfangjun:log$ grep -r -n --color "validation:" log-train-2022-11-06-21-04-40-0  | head -n 15
32:2022-11-06 21:07:39,917 INFO [train.py:923] (0/4) Epoch 1, validation: loss=6.934, simple_loss=6.259, pruned_loss=6.727, over 944034.00 frames.
385:2022-11-06 22:01:54,263 INFO [train.py:923] (0/4) Epoch 1, validation: loss=0.2959, simple_loss=0.3905, pruned_loss=0.1006, over 944034.00 frames.
471:2022-11-06 22:13:30,056 INFO [train.py:923] (0/4) Epoch 2, validation: loss=0.2658, simple_loss=0.3605, pruned_loss=0.08557, over 944034.00 frames.
869:2022-11-06 23:25:34,858 INFO [train.py:923] (0/4) Epoch 2, validation: loss=0.2223, simple_loss=0.3212, pruned_loss=0.06165, over 944034.00 frames.
952:2022-11-06 23:37:58,548 INFO [train.py:923] (0/4) Epoch 3, validation: loss=0.2207, simple_loss=0.3204, pruned_loss=0.06049, over 944034.00 frames.
1373:2022-11-07 00:47:49,999 INFO [train.py:923] (0/4) Epoch 3, validation: loss=0.2036, simple_loss=0.3052, pruned_loss=0.05099, over 944034.00 frames.
1454:2022-11-07 01:00:08,376 INFO [train.py:923] (0/4) Epoch 4, validation: loss=0.2057, simple_loss=0.3072, pruned_loss=0.05203, over 944034.00 frames.
1875:2022-11-07 02:10:04,402 INFO [train.py:923] (0/4) Epoch 4, validation: loss=0.192, simple_loss=0.2944, pruned_loss=0.04475, over 944034.00 frames.
1929:2022-11-07 02:22:25,695 INFO [train.py:923] (0/4) Epoch 5, validation: loss=0.1906, simple_loss=0.2936, pruned_loss=0.04381, over 944034.00 frames.
2304:2022-11-07 03:32:20,114 INFO [train.py:923] (0/4) Epoch 5, validation: loss=0.1814, simple_loss=0.2851, pruned_loss=0.03888, over 944034.00 frames.
2377:2022-11-07 03:44:36,352 INFO [train.py:923] (0/4) Epoch 6, validation: loss=0.1825, simple_loss=0.286, pruned_loss=0.03951, over 944034.00 frames.
2839:2022-11-07 04:54:32,569 INFO [train.py:923] (0/4) Epoch 6, validation: loss=0.1757, simple_loss=0.279, pruned_loss=0.03619, over 944034.00 frames.
2908:2022-11-07 05:06:51,631 INFO [train.py:923] (0/4) Epoch 7, validation: loss=0.1765, simple_loss=0.2801, pruned_loss=0.03646, over 944034.00 frames.
3300:2022-11-07 06:16:47,922 INFO [train.py:923] (0/4) Epoch 7, validation: loss=0.1706, simple_loss=0.2744, pruned_loss=0.03339, over 944034.00 frames.
3376:2022-11-07 06:29:05,468 INFO [train.py:923] (0/4) Epoch 8, validation: loss=0.1708, simple_loss=0.2745, pruned_loss=0.03354, over 944034.00 frames.
(py38) kuangfangjun:log$ grep -r -n --color "validation:" log-train-2022-11-06-21-04-40-0  | tail -n 15
10947:2022-11-08 04:14:30,541 INFO [train.py:923] (0/4) Epoch 23, validation: loss=0.1505, simple_loss=0.2502, pruned_loss=0.02537, over 944034.00 frames.
11030:2022-11-08 04:26:46,624 INFO [train.py:923] (0/4) Epoch 24, validation: loss=0.1498, simple_loss=0.25, pruned_loss=0.02478, over 944034.00 frames.
11477:2022-11-08 05:36:44,939 INFO [train.py:923] (0/4) Epoch 24, validation: loss=0.1492, simple_loss=0.2489, pruned_loss=0.02477, over 944034.00 frames.
11544:2022-11-08 05:48:59,527 INFO [train.py:923] (0/4) Epoch 25, validation: loss=0.15, simple_loss=0.2496, pruned_loss=0.02523, over 944034.00 frames.
12012:2022-11-08 06:58:43,479 INFO [train.py:923] (0/4) Epoch 25, validation: loss=0.1494, simple_loss=0.249, pruned_loss=0.02494, over 944034.00 frames.
12074:2022-11-08 07:11:04,005 INFO [train.py:923] (0/4) Epoch 26, validation: loss=0.1498, simple_loss=0.2493, pruned_loss=0.02517, over 944034.00 frames.
12515:2022-11-08 08:21:01,368 INFO [train.py:923] (0/4) Epoch 26, validation: loss=0.1497, simple_loss=0.2488, pruned_loss=0.02532, over 944034.00 frames.
12585:2022-11-08 08:33:23,160 INFO [train.py:923] (0/4) Epoch 27, validation: loss=0.1494, simple_loss=0.2488, pruned_loss=0.02496, over 944034.00 frames.
12984:2022-11-08 09:43:34,231 INFO [train.py:923] (0/4) Epoch 27, validation: loss=0.1491, simple_loss=0.2482, pruned_loss=0.02496, over 944034.00 frames.
13049:2022-11-08 09:55:52,480 INFO [train.py:923] (0/4) Epoch 28, validation: loss=0.1497, simple_loss=0.2489, pruned_loss=0.02525, over 944034.00 frames.
13464:2022-11-08 11:05:46,453 INFO [train.py:923] (0/4) Epoch 28, validation: loss=0.149, simple_loss=0.2476, pruned_loss=0.0252, over 944034.00 frames.
13545:2022-11-08 11:18:14,997 INFO [train.py:923] (0/4) Epoch 29, validation: loss=0.1495, simple_loss=0.2482, pruned_loss=0.02537, over 944034.00 frames.
13943:2022-11-08 12:28:17,173 INFO [train.py:923] (0/4) Epoch 29, validation: loss=0.1491, simple_loss=0.2478, pruned_loss=0.02521, over 944034.00 frames.
14008:2022-11-08 12:40:32,335 INFO [train.py:923] (0/4) Epoch 30, validation: loss=0.1502, simple_loss=0.2487, pruned_loss=0.02586, over 944034.00 frames.
14423:2022-11-08 13:50:36,196 INFO [train.py:923] (0/4) Epoch 30, validation: loss=0.1487, simple_loss=0.2472, pruned_loss=0.02508, over 944034.00 frames.
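
If you want to extract these curves yourself, the validation lines are easy to parse out of the training log; a small sketch, assuming the log format shown above:

import re

# Pull (epoch, validation loss) pairs out of the training log.
pattern = re.compile(r"Epoch (\d+), validation: loss=([\d.]+)")

points = []
with open("log-train-2022-11-06-21-04-40-0") as f:
    for line in f:
        m = pattern.search(line)
        if m:
            points.append((int(m.group(1)), float(m.group(2))))

for epoch, loss in points:
    print(epoch, loss)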

(2) WERs for earlier epochs (greedy search)

test-clean  test-other  comment
6.54        14.48       epoch 2, avg 1
4.78        11.37       epoch 3, avg 1
4.16        9.74        epoch 4, avg 1
3.74        8.84        epoch 5, avg 1
3.44        8.24        epoch 6, avg 1
3.16        7.73        epoch 7, avg 1
3.05        7.18        epoch 8, avg 1
2.87        7.02        epoch 9, avg 1
2.84        6.62        epoch 10, avg 1
2.71        6.53        epoch 11, avg 1
2.66        6.42        epoch 11, avg 2
2.64        6.32        epoch 12, avg 1
2.61        6.24        epoch 12, avg 2
2.56        6.29        epoch 13, avg 1
2.55        6.14        epoch 13, avg 2
2.58        6.14        epoch 14, avg 1
2.54        6.07        epoch 14, avg 2
2.48        6.03        epoch 15, avg 1
2.46        5.94        epoch 15, avg 2
2.49        5.90        epoch 16, avg 1
2.47        5.84        epoch 16, avg 2
2.50        5.89        epoch 17, avg 1
2.42        5.71        epoch 17, avg 2
2.40        5.99        epoch 18, avg 1
2.36        5.81        epoch 18, avg 2
2.44        5.86        epoch 19, avg 1
2.36        5.84        epoch 19, avg 2
2.39        5.84        epoch 20, avg 1
2.37        5.72        epoch 20, avg 2

I am uploading the pre-trained models, tensorboard logs, and decoding results to Hugging Face.

@danpovey (Collaborator)

Remember to squash when you merge (657 commits... it will bulk up the repo size).

@csukuangfj (Collaborator, Author)

> Remember to squash when you merge (657 commits... it will bulk up the repo size).

Thanks. I will.


I am making it support torchscript.
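
For reference, TorchScript support usually comes down to making the model code scriptable and then exporting it. A generic sketch, not the PR's actual export script; model and the file name are placeholders:

import torch

model.eval()
scripted = torch.jit.script(model)  # requires the model code to be scriptable
scripted.save("zipformer_jit.pt")

# Later, load and run without the Python class definitions:
loaded = torch.jit.load("zipformer_jit.pt")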


A collaborator commented on this diff hunk:

    class ZipformerEncoderLayer(nn.Module):
        """
        ZipformerEncoderLayer is made up of self-attn, feedforward and convolution networks.

The docs should be updated.

@csukuangfj (Collaborator, Author)

The tensorboard log is available at
https://tensorboard.dev/experiment/P7vXWqK7QVu1mU9Ene1gGg/

[Screenshot: TensorBoard training curves]

@csukuangfj (Collaborator, Author)

Here is a comparison between the Zipformer and our previous reworked Conformer.
(The results for the reworked Conformer are from https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md)

Note: I list only our previous best results without LM rescoring and without training on extra data.

name                                    test-clean  test-other  comments                                num parameters (M)
pruned_transducer_stateless7 (this PR)  2.46        5.94        greedy search, epoch 15, avg 2          70.37
pruned_transducer_stateless7 (this PR)  2.17        5.23        greedy search, epoch 30, avg 9          70.37
pruned_transducer_stateless7 (this PR)  2.15        5.20        modified beam search, epoch 30, avg 9   70.37
pruned_transducer_stateless7 (this PR)  2.15        5.22        fast beam search, epoch 30, avg 9       70.37
pruned_transducer_stateless5 (large)    2.43        5.72        greedy search, epoch 30, avg 10         118.13
pruned_transducer_stateless5 (large)    2.43        5.69        modified beam search, epoch 30, avg 10  118.13
pruned_transducer_stateless5 (large)    2.43        5.67        fast beam search, epoch 30, avg 10      118.13
pruned_transducer_stateless5            2.54        5.72        greedy search, epoch 30, avg 10         87.8
pruned_transducer_stateless5            2.47        5.71        modified beam search, epoch 30, avg 10  87.8
pruned_transducer_stateless5            2.50        5.72        fast beam search, epoch 30, avg 10      87.8

@csukuangfj (Collaborator, Author)

I will merge it after the CI passes.

@csukuangfj merged commit 7e82f87 into master on Nov 12, 2022
@csukuangfj deleted the from-dan-scaled-adam-exp253 branch on Nov 13, 2022
@csukuangfj
Copy link
Collaborator Author

You can try the pre-trained model of this PR from your browser without installing anything.

Just go to https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition and record your voice for recognition.

[Screenshot: the Hugging Face speech recognition demo]
