From 6675344f0e3d2ea4021009049e8f29ce92a865f2 Mon Sep 17 00:00:00 2001 From: Jonatas Grosman Date: Wed, 15 Feb 2023 20:22:58 -0300 Subject: [PATCH] fix bug in reshaping labels --- src/transformers/models/whisper/modeling_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 0313613b70e7..a9cbe50ce90a 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1211,7 +1211,7 @@ def forward( loss = None if labels is not None: loss_fct = CrossEntropyLoss() - loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1)) if not return_dict: output = (lm_logits,) + outputs[1:]