
[Not for merge] Madam optimizer with OOM handling #8

Open
wants to merge 12 commits into master
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
-  rev: 21.6b0
+  rev: 21.7b0
hooks:
- id: black
args: [--line-length=80]
@@ -869,7 +869,10 @@ def __init__(
groups=channels,
bias=bias,
)
-        self.norm = nn.BatchNorm1d(channels)
+        # NOTE(fangjun): The process hangs when using DDP
+        # if we try to recover from CUDA OOM, so we disable
+        # the batchnorm layer here.
+        # self.norm = nn.BatchNorm1d(channels)
Collaborator Author
After disabling batch norm, training with DDP can now recover from OOM in model.forward() as expected.
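
As a reference for what "recover from OOM" means here, a minimal sketch of the pattern, not the code in this PR: the helper name train_one_batch and the way the loss is obtained from model(batch) are illustrative only, and with DDP all ranks also have to stay in step on collective operations when a batch is skipped, which appears to be where the batchnorm layer was getting in the way (see the NOTE in the diff above).

import gc

import torch


def train_one_batch(model, optimizer, batch) -> bool:
    """Run one forward/backward/step.  Return False if the batch was
    skipped because of a CUDA out-of-memory error.  Sketch only."""
    try:
        loss = model(batch)  # forward; may raise a CUDA OOM RuntimeError
        optimizer.zero_grad()
        loss.backward()  # backward; may also raise a CUDA OOM RuntimeError
        optimizer.step()
        return True
    except RuntimeError as ex:
        if "out of memory" not in str(ex):
            raise
        # Drop references that may pin GPU memory; `loss` is undefined if
        # the OOM happened inside the forward pass.
        try:
            del loss
        except NameError:
            pass
        optimizer.zero_grad(set_to_none=True)
        gc.collect()
        torch.cuda.empty_cache()
        return False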

Collaborator
Mm. Hopefully with the Madam optimizer the training will still be stable without the batchnorm; we'll have to see. We'd obviously have to compare the performance after this change.

self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
@@ -899,7 +902,8 @@ def forward(self, x: Tensor) -> Tensor:

# 1D Depthwise Conv
x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
+        # x = self.activation(self.norm(x))
+        x = self.activation(x)

x = self.pointwise_conv2(x) # (batch, channel, time)

115 changes: 47 additions & 68 deletions egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py
@@ -153,7 +153,7 @@ def get_params() -> AttributeDict:
"num_decoder_layers": 6,
"is_espnet_structure": True,
"mmi_loss": False,
"use_feat_batchnorm": True,
"use_feat_batchnorm": False,
"lr_factor": 2.0,
"warm_step": 30000,
}
@@ -282,75 +282,59 @@ def compute_loss_impl(
assert feature.ndim == 3
feature = feature.to(device)

-    try:
-        supervisions = batch["supervisions"]
-
-        with torch.set_grad_enabled(is_training):
-            nnet_output, encoder_memory, memory_mask = model(
-                feature, supervisions
-            )
-            # nnet_output is [N, T, C]
-
-        # NOTE: We need `encode_supervisions` to sort sequences with
-        # different duration in decreasing order, required by
-        # `k2.intersect_dense` called in `k2.ctc_loss`
-        supervision_segments, texts = encode_supervisions(
-            supervisions, subsampling_factor=params.subsampling_factor
-        )
-
-        token_ids = graph_compiler.texts_to_ids(texts)
-
-        decoding_graph = graph_compiler.compile(token_ids)
-
-        dense_fsa_vec = k2.DenseFsaVec(
-            nnet_output,
-            supervision_segments,
-            allow_truncate=params.subsampling_factor - 1,
-        )
-
-        ctc_loss = k2.ctc_loss(
-            decoding_graph=decoding_graph,
-            dense_fsa_vec=dense_fsa_vec,
-            output_beam=params.beam_size,
-            reduction=params.reduction,
-            use_double_scores=params.use_double_scores,
-        )
-
-        if params.att_rate != 0.0:
-            with torch.set_grad_enabled(is_training):
-                if hasattr(model, "module"):
-                    att_loss = model.module.decoder_forward(
-                        encoder_memory,
-                        memory_mask,
-                        token_ids=token_ids,
-                        sos_id=graph_compiler.sos_id,
-                        eos_id=graph_compiler.eos_id,
-                    )
-                else:
-                    att_loss = model.decoder_forward(
-                        encoder_memory,
-                        memory_mask,
-                        token_ids=token_ids,
-                        sos_id=graph_compiler.sos_id,
-                        eos_id=graph_compiler.eos_id,
-                    )
-            loss = (
-                1.0 - params.att_rate
-            ) * ctc_loss + params.att_rate * att_loss
-        else:
-            loss = ctc_loss
-            att_loss = torch.tensor([0])
-    except RuntimeError as ex:
-        try:
-            del nnet_output
-            del encoder_memory
-            del dense_fsa_vec
-            del ctc_loss
-            del att_loss
-            del loss
-        except NameError as ne:
-            pass
-        raise ex
+    supervisions = batch["supervisions"]
+    with torch.set_grad_enabled(is_training):
+        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+        # nnet_output is [N, T, C]
+
+    # NOTE: We need `encode_supervisions` to sort sequences with
+    # different duration in decreasing order, required by
+    # `k2.intersect_dense` called in `k2.ctc_loss`
+    supervision_segments, texts = encode_supervisions(
+        supervisions, subsampling_factor=params.subsampling_factor
+    )
+
+    token_ids = graph_compiler.texts_to_ids(texts)
+
+    decoding_graph = graph_compiler.compile(token_ids)
+
+    dense_fsa_vec = k2.DenseFsaVec(
+        nnet_output,
+        supervision_segments,
+        allow_truncate=params.subsampling_factor - 1,
+    )
+
+    ctc_loss = k2.ctc_loss(
+        decoding_graph=decoding_graph,
+        dense_fsa_vec=dense_fsa_vec,
+        output_beam=params.beam_size,
+        reduction=params.reduction,
+        use_double_scores=params.use_double_scores,
+    )
+
+    if params.att_rate != 0.0:
+        with torch.set_grad_enabled(is_training):
+            if hasattr(model, "module"):
+                att_loss = model.module.decoder_forward(
+                    encoder_memory,
+                    memory_mask,
+                    token_ids=token_ids,
+                    sos_id=graph_compiler.sos_id,
+                    eos_id=graph_compiler.eos_id,
+                )
+            else:
+                att_loss = model.decoder_forward(
+                    encoder_memory,
+                    memory_mask,
+                    token_ids=token_ids,
+                    sos_id=graph_compiler.sos_id,
+                    eos_id=graph_compiler.eos_id,
+                )
+        loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
+    else:
+        loss = ctc_loss
+        att_loss = torch.tensor([0])

# train_frames and valid_frames are used for printing.
if is_training:
@@ -394,11 +378,6 @@ def compute_loss(
s += f" max duration: {max_cut_duration:.3f} s \n"
logging.info(s)

-    # see https://github.com/pytorch/fairseq/blob/50a671f78d0c8de0392f924180db72ac9b41b801/fairseq/trainer.py#L283
-    for p in model.parameters():
-        if p.grad is not None:
-            del p.grad  # free some memory

torch.cuda.empty_cache()

gc.collect()
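
Taken together, the train.py changes move the OOM cleanup out of compute_loss_impl: the callee now simply lets the RuntimeError propagate, and the caller (per the surviving lines above) logs the offending batch, empties the CUDA cache, and runs gc.collect(). A generic sketch of that caller-side shape follows; it is an assumption about the intended pattern rather than code from this PR, and _is_cuda_oom / call_with_oom_retry are made-up names.

import gc
import logging
from typing import Any, Callable

import torch


def _is_cuda_oom(ex: RuntimeError) -> bool:
    # PyTorch surfaces a CUDA OOM as a RuntimeError whose message
    # contains "out of memory".
    return "out of memory" in str(ex)


def call_with_oom_retry(fn: Callable[[], Any], max_retries: int = 1) -> Any:
    """Call fn(); on CUDA OOM, free cached memory and retry, giving up
    after max_retries retries."""
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except RuntimeError as ex:
            if not _is_cuda_oom(ex) or attempt == max_retries:
                raise
            logging.info(
                "CUDA OOM on attempt %d; emptying cache and retrying", attempt
            )
            torch.cuda.empty_cache()
            gc.collect()

compute_loss could then wrap the callee with something like call_with_oom_retry(lambda: compute_loss_impl(...)), with the actual arguments following whatever signature compute_loss_impl has in this PR.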