
Commit

Fix
katalinic-gc committed Jun 7, 2023
1 parent a06f5ff commit 9a7f19c
Showing 3 changed files with 8 additions and 6 deletions.
8 changes: 4 additions & 4 deletions optimum/graphcore/generation/utils.py
@@ -299,13 +299,13 @@ def _get_generation_step_tensor(self, generation_step, ascending=False):
         return per_replica.repeat(decoder_ipu_config.inference_replication_factor)
 
     def _populate_parallelize_kwargs_with_generation_config(self, **kwargs):
-        if not kwargs.get("use_cache", False) or self.generation_config is None:
+        if self.generation_config is None:
             return kwargs
 
-        for kwarg in ["num_beams, max_length"]:
+        for kwarg in ["num_beams", "max_length"]:
             if kwarg not in kwargs:
-                kwarg_value = self.generation_config.kwarg
-                kwargs["kwarg"] = kwarg_value
+                kwarg_value = getattr(self.generation_config, kwarg)
+                kwargs[kwarg] = kwarg_value
                 logger.info(f"Setting parallelize kwarg `{kwarg}` to value in generation_config ({kwarg_value}).")
 
         return kwargs
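
The corrected helper is an instance of a common fallback pattern: any kwarg the caller did not supply is read from a config object by attribute name. A minimal standalone sketch of that pattern follows; populate_kwargs and the SimpleNamespace stand-in for generation_config are illustrative, not part of the repository.

from types import SimpleNamespace

# Illustrative stand-in for the model's generation_config; in
# optimum-graphcore this is a transformers GenerationConfig.
generation_config = SimpleNamespace(num_beams=4, max_length=128)

def populate_kwargs(generation_config, **kwargs):
    # Mirrors the fixed loop: iterate over attribute *names*, then read each
    # one with getattr and store it under the variable key rather than the
    # string literal "kwarg", which was the bug this hunk corrects.
    for kwarg in ["num_beams", "max_length"]:
        if kwarg not in kwargs:
            kwargs[kwarg] = getattr(generation_config, kwarg)
    return kwargs

print(populate_kwargs(generation_config, num_beams=8))
# -> {'num_beams': 8, 'max_length': 128}: caller-supplied values win.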
3 changes: 2 additions & 1 deletion optimum/graphcore/models/bart/modeling_bart.py
@@ -771,7 +771,8 @@ def parallelize(self, for_generation=False, use_cache=False, **kwargs):
"""
super().parallelize()

kwargs = self._populate_parallelize_kwargs_with_generation_config(**kwargs)
if use_cache:
kwargs = self._populate_parallelize_kwargs_with_generation_config(**kwargs)

logger.info("-------------------- Device Allocation --------------------")
logger.info("Embedding --> IPU 0")
3 changes: 2 additions & 1 deletion optimum/graphcore/models/whisper/modeling_whisper.py
@@ -388,7 +388,8 @@ def change_attention_class(self, restore=False, **kwargs):
     def parallelize(self, for_generation=False, use_cache=False, use_cross_cache=False, **kwargs):
         super().parallelize()
 
-        kwargs = self._populate_parallelize_kwargs_with_generation_config(**kwargs)
+        if use_cache:
+            kwargs = self._populate_parallelize_kwargs_with_generation_config(**kwargs)
 
         self.change_encoder_layer_class(restore=False)
         self.change_decoder_class(restore=False)
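
The two model hunks above apply the same gate: the generation-config fallback now runs only when use_cache is true, presumably because those values are only needed when cache-related state is set up. A self-contained toy sketch of the combined behavior; ToyPipelinedModel is illustrative only, the real classes live in optimum.graphcore.

from types import SimpleNamespace

class ToyPipelinedModel:
    def __init__(self):
        # Stand-in for a transformers GenerationConfig.
        self.generation_config = SimpleNamespace(num_beams=4, max_length=128)

    def _populate_parallelize_kwargs_with_generation_config(self, **kwargs):
        # Same fallback logic as the utils.py hunk above, condensed.
        for kwarg in ["num_beams", "max_length"]:
            kwargs.setdefault(kwarg, getattr(self.generation_config, kwarg))
        return kwargs

    def parallelize(self, for_generation=False, use_cache=False, **kwargs):
        # The gate this commit adds: generation_config is consulted only
        # when the KV cache is in use.
        if use_cache:
            kwargs = self._populate_parallelize_kwargs_with_generation_config(**kwargs)
        return kwargs

model = ToyPipelinedModel()
print(model.parallelize(use_cache=True))   # {'num_beams': 4, 'max_length': 128}
print(model.parallelize(use_cache=False))  # {}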
