Skip to content

Commit

Permalink
Use index select for whisper position embedding for better tile utili…
Browse files Browse the repository at this point in the history
…zation (#435)
  • Loading branch information
katalinic-gc authored Jun 27, 2023
1 parent d492f27 commit 5ec9a78
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion optimum/graphcore/models/whisper/modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
if input_ids.shape[-1] == 1:
# KV cache enabled.
del past_key_values_length
return poptorch.dynamic_slice(self.weight, 0, self._generation_step, 1, 1)
return torch.index_select(self.weight, 0, self._generation_step)
else:
return super().forward(input_ids, past_key_values_length=past_key_values_length)

Expand Down

0 comments on commit 5ec9a78

Please sign in to comment.