Apply delay penalty on transducer #654

Merged (11 commits) on Nov 4, 2022
egs/librispeech/ASR/lstm_transducer_stateless/model.py (8 additions, 0 deletions)
@@ -81,6 +81,7 @@ def forward(
lm_scale: float = 0.0,
warmup: float = 1.0,
reduction: str = "sum",
delay_penalty: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@@ -108,6 +109,11 @@ def forward(
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
delay_penalty:
A constant value used to penalize symbol delay, to encourage
streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.
Review comment from @ezerhouni (Collaborator), Nov 3, 2022:

Great paper! One thing I noticed: the WER numbers in Fig. 4a and Fig. 5a seem to be wrong, since they are supposed to be test-clean results, which are in the 3-4% range. Is that correct?
Reply from the PR author (Collaborator):

Thanks. The results presented in Figures 3, 4, and 5 are averaged over the test-clean and test-other subsets for comparison.

Returns:
Return the transducer loss.

@@ -164,6 +170,7 @@ def forward(
am_only_scale=am_scale,
boundary=boundary,
reduction=reduction,
delay_penalty=delay_penalty,
return_grad=True,
)

@@ -196,6 +203,7 @@ def forward(
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
delay_penalty=delay_penalty,
reduction=reduction,
)

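For context on what the new `delay_penalty` argument does inside k2: conceptually, the pruned RNN-T loss adds a per-frame offset to the non-blank emission scores so that alignments which emit symbols earlier score better. The snippet below is a simplified sketch of that idea, not k2's exact implementation (the tensor name `px` loosely follows k2's pruned loss, and the real code also takes the `boundary` information into account; see https://github.com/k2-fsa/k2/issues/955 for details):

```python
import torch


def apply_delay_penalty(px: torch.Tensor, delay_penalty: float) -> torch.Tensor:
    """Illustrative sketch of the symbol-delay penalty, not k2's exact code.

    px: (B, S, T) log-probs of emitting each non-blank symbol at each frame.
    """
    T = px.shape[-1]
    # Frames before the utterance midpoint receive a bonus and frames after
    # it a penalty, so the loss prefers alignments that emit symbols earlier.
    frame_idx = torch.arange(T, dtype=px.dtype, device=px.device)
    offset = delay_penalty * ((T - 1) / 2 - frame_idx)  # shape (T,)
    return px + offset  # broadcasts over (B, S, T)
```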
egs/librispeech/ASR/lstm_transducer_stateless/train.py (11 additions, 0 deletions)
@@ -318,6 +318,16 @@ def get_parser():
help="Whether to use half precision training.",
)

parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value used to penalize symbol delay,
to encourage streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.""",
)

add_model_arguments(parser)

return parser
@@ -611,6 +621,7 @@ def compute_loss(
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
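Note the gating in compute_loss above: the penalty is passed to the model only once warmup >= 2.0, i.e., after the warm-up phase, so early training is unaffected. To turn the penalty on, pass a positive value such as `--delay-penalty 0.015` to train.py (that value is purely illustrative; the default of 0.0 disables the penalty, and the paper studies the WER/latency trade-off across several values). The same argument and gating are added to each of the recipes below.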
egs/librispeech/ASR/lstm_transducer_stateless2/model.py (8 additions, 0 deletions)
@@ -106,6 +106,7 @@ def forward(
lm_scale: float = 0.0,
warmup: float = 1.0,
reduction: str = "sum",
delay_penalty: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@@ -136,6 +137,11 @@ def forward(
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
delay_penalty:
A constant value used to penalize symbol delay, to encourage
streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.
Returns:
Return the transducer loss.

@@ -203,6 +209,7 @@ def forward(
am_only_scale=am_scale,
boundary=boundary,
reduction=reduction,
delay_penalty=delay_penalty,
return_grad=True,
)

@@ -235,6 +242,7 @@ def forward(
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
delay_penalty=delay_penalty,
reduction=reduction,
)

egs/librispeech/ASR/lstm_transducer_stateless2/train.py (11 additions, 0 deletions)
@@ -341,6 +341,16 @@ def get_parser():
help="The probability to select a batch from the GigaSpeech dataset",
)

parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value used to penalize symbol delay,
to encourage streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.""",
)

add_model_arguments(parser)

return parser
@@ -665,6 +675,7 @@ def compute_loss(
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
egs/librispeech/ASR/lstm_transducer_stateless3/train.py (11 additions, 0 deletions)
@@ -328,6 +328,16 @@ def get_parser():
help="Whether to use half precision training.",
)

parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value used to penalize symbol delay,
to encourage streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.""",
)

add_model_arguments(parser)

return parser
@@ -623,6 +633,7 @@ def compute_loss(
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
egs/librispeech/ASR/pruned_transducer_stateless2/model.py (9 additions, 0 deletions)
@@ -81,6 +81,7 @@ def forward(
lm_scale: float = 0.0,
warmup: float = 1.0,
reduction: str = "sum",
delay_penalty: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@@ -108,6 +109,12 @@ def forward(
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
delay_penalty:
A constant value used to penalize symbol delay, to encourage
streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.
Returns:
Return the transducer loss.

@@ -164,6 +171,7 @@ def forward(
am_only_scale=am_scale,
boundary=boundary,
reduction=reduction,
delay_penalty=delay_penalty,
return_grad=True,
)

@@ -196,6 +204,7 @@ def forward(
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
delay_penalty=delay_penalty,
reduction=reduction,
)

egs/librispeech/ASR/pruned_transducer_stateless2/train.py (11 additions, 0 deletions)
@@ -317,6 +317,16 @@ def get_parser():
help="Whether to use half precision training.",
)

parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value used to penalize symbol delay,
to encourage streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.""",
)

add_model_arguments(parser)

return parser
@@ -607,6 +617,7 @@ def compute_loss(
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
egs/librispeech/ASR/pruned_transducer_stateless3/model.py (8 additions, 0 deletions)
@@ -106,6 +106,7 @@ def forward(
lm_scale: float = 0.0,
warmup: float = 1.0,
reduction: str = "sum",
delay_penalty: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@@ -136,6 +137,11 @@ def forward(
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
delay_penalty:
A constant value used to penalize symbol delay, to encourage
streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.
Returns:
Return the transducer loss.

@@ -203,6 +209,7 @@ def forward(
am_only_scale=am_scale,
boundary=boundary,
reduction=reduction,
delay_penalty=delay_penalty,
return_grad=True,
)

@@ -235,6 +242,7 @@ def forward(
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
delay_penalty=delay_penalty,
reduction=reduction,
)

egs/librispeech/ASR/pruned_transducer_stateless3/train.py (11 additions, 0 deletions)
@@ -328,6 +328,16 @@ def get_parser():
help="The probability to select a batch from the GigaSpeech dataset",
)

parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value used to penalize symbol delay,
to encourage streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.""",
)

add_model_arguments(parser)
return parser

@@ -645,6 +655,7 @@ def compute_loss(
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
egs/librispeech/ASR/pruned_transducer_stateless4/train.py (11 additions, 0 deletions)
@@ -335,6 +335,16 @@ def get_parser():
help="Whether to use half precision training.",
)

parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value used to penalize symbol delay,
to encourage streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.""",
)

add_model_arguments(parser)

return parser
@@ -638,6 +648,7 @@ def compute_loss(
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
egs/librispeech/ASR/pruned_transducer_stateless5/train.py (11 additions, 0 deletions)
@@ -368,6 +368,16 @@ def get_parser():
help="Whether to use half precision training.",
)

parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value used to penalize symbol delay,
to encourage streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.""",
)

add_model_arguments(parser)

return parser
@@ -662,6 +672,7 @@ def compute_loss(
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)