Efficient decoder text generation wrapper (huggingface#273)
* Set the default matmul_proportion in IPUConfig to 0.2 so that the default config works with the decoder wrapper
jimypbr authored and ncouro-gc committed Mar 17, 2023
1 parent 1a5579f commit 609c870
Showing 3 changed files with 15 additions and 33 deletions.
26 changes: 4 additions & 22 deletions optimum/graphcore/generation_utils.py
@@ -535,6 +535,7 @@ def beam_search(
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

outputs = self._call_generate(
+ t=torch.tensor(cur_len - 1),
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
@@ -545,13 +546,6 @@
if not self.config.is_encoder_decoder:
model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :cur_len]

- outputs.logits = outputs.logits[:, :cur_len, :]
- if outputs.logits.dim() == 3:
- outputs.logits = outputs.logits[:, :cur_len, :]
- # If the dimension of logits is 2, then only the logits of the last non-padding token is returned, so no need to slice.
- else:
- next_token_logits = outputs.logits

# Change: remove synced_gpu code

# Change: cast to float on cpu
@@ -840,6 +834,7 @@ def sample(

# forward pass to get next token
outputs = self._call_generate(
+ t=torch.tensor(cur_len - 1),
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
@@ -850,13 +845,6 @@
if not self.config.is_encoder_decoder:
model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :cur_len]

- outputs.logits = outputs.logits[:, :cur_len, :]
- if outputs.logits.dim() == 3:
- outputs.logits = outputs.logits[:, :cur_len, :]
- # If the dimension of logits is 2, then only the logits of the last non-padding token is returned, so no need to slice.
- else:
- next_token_logits = outputs.logits

# Change: remove synced_gpu code

# Change: cast to float on cpu
@@ -1119,6 +1107,7 @@ def beam_sample(
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

outputs = self._call_generate(
+ t=torch.tensor(cur_len - 1),
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
@@ -1129,13 +1118,6 @@
if not self.config.is_encoder_decoder:
model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :cur_len]

- outputs.logits = outputs.logits[:, :cur_len, :]
- if outputs.logits.dim() == 3:
- outputs.logits = outputs.logits[:, :cur_len, :]
- # If the dimension of logits is 2, then only the logits of the last non-padding token is returned, so no need to slice.
- else:
- next_token_logits = outputs.logits

# Change: remove synced_gpu code

# Change: cast to float on cpu
@@ -1253,4 +1235,4 @@ def beam_sample(
hidden_states=decoder_hidden_states,
)
else:
return sequence_outputs["sequences"]
return sequence_outputs["sequences"]
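The `t` argument added to each `self._call_generate` call above carries the current decoding position into the wrapped decoder; the per-step logits slicing is removed in the same hunks, presumably because the wrapper now returns only the logits needed for the current step. A minimal sketch of one decode step in this style, with illustrative names (`decode_step`, `call_generate`) standing in for the wrapper's internals:

```python
import torch

def decode_step(call_generate, model_inputs, cur_len):
    # Illustrative sketch only: `call_generate` stands in for self._call_generate.
    # With the decoder wrapper the current position is passed in as `t`, and the
    # call is assumed to return logits of shape (batch, vocab) for that position
    # alone, so no per-step slicing of a (batch, seq, vocab) tensor is needed.
    outputs = call_generate(
        t=torch.tensor(cur_len - 1),  # 0-based index of the token being generated
        **model_inputs,
        return_dict=True,
    )
    # Cast to float on the CPU before the logits processors run, as the loops above do.
    return outputs.logits.float()
```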
2 changes: 1 addition & 1 deletion optimum/graphcore/ipu_configuration.py
@@ -138,7 +138,7 @@ def __init__(self, **kwargs):
'The "sharded_execution_for_inference" parameter is deprecated, sharded execution is always used during inference'
)

- self.matmul_proportion = kwargs.pop("matmul_proportion", 0.6)
+ self.matmul_proportion = kwargs.pop("matmul_proportion", 0.2)

if "enable_half_first_order_momentum" in kwargs:
warnings.warn('The "enable_half_first_order_momentum" parameter is deprecated')
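A quick sketch of what the new default means in practice; `layers_per_ipu=[12]` is an illustrative value borrowed from the pipeline defaults below:

```python
from optimum.graphcore import IPUConfig

# Omitting matmul_proportion now falls back to the new default of 0.2 ...
cfg_default = IPUConfig(layers_per_ipu=[12])
print(cfg_default.matmul_proportion)  # expected: 0.2

# ... while models that need more memory for matmuls can still set it explicitly.
cfg_custom = IPUConfig(layers_per_ipu=[12], matmul_proportion=0.6)
print(cfg_custom.matmul_proportion)   # 0.6
```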
20 changes: 10 additions & 10 deletions optimum/graphcore/pipelines/__init__.py
@@ -138,7 +138,7 @@ class IncompatibleIPUConfigError(Exception):
"class": (AutoModelForCausalLM,),
"default": {
"model": ("gpt2", "e7da7f2"),
"ipu_config": "Graphcore/gpt2-small-ipu",
"ipu_config": IPUConfig(layers_per_ipu=[12], matmul_proportion=0.2),
"max_length": 50,
},
"type": "text",
@@ -148,7 +148,7 @@ class IncompatibleIPUConfigError(Exception):
"class": (AutoModelForSeq2SeqLM,),
"default": {
"model": ("ainize/bart-base-cnn", "b90bc9a"),
"ipu_config": "Graphcore/bart-base-ipu",
"ipu_config": IPUConfig(layers_per_ipu=[12], matmul_proportion=0.2),
"max_input_length": 50,
"max_length": 20,
"truncation": "only_first",
@@ -161,7 +161,7 @@ class IncompatibleIPUConfigError(Exception):
"class": (AutoModelForSeq2SeqLM,),
"default": {
"model": ("t5-small", "9507060"),
"ipu_config": "Graphcore/t5-small-ipu",
"ipu_config": IPUConfig(layers_per_ipu=[12], matmul_proportion=0.2),
"max_length": 50,
"max_input_length": 45,
"truncation": "only_first",
@@ -173,7 +173,7 @@ class IncompatibleIPUConfigError(Exception):
"class": (AutoModelForSeq2SeqLM,),
"default": {
"model": ("t5-small", "9507060"),
"ipu_config": "Graphcore/t5-small-ipu",
"ipu_config": IPUConfig(layers_per_ipu=[12], matmul_proportion=0.2),
"max_length": 50,
"max_input_length": 50,
"truncation": "only_first",
@@ -210,7 +210,7 @@ def list_tasks() -> List[str]:

def get_poplar_executor(
model: PreTrainedModel,
- ipu_config: Union[str, dict] = None,
+ ipu_config: Union[IPUConfig, str, dict] = None,
fp16: bool = True,
) -> PreTrainedModel:
ipu_config_arg = ipu_config
@@ -219,8 +219,8 @@ def get_poplar_executor(
ipu_config = IPUConfig.from_pretrained(ipu_config)
elif isinstance(ipu_config, dict):
ipu_config = IPUConfig.from_dict(ipu_config)
- else:
- raise ValueError("ipu_config must be a string or a dictionary.")
+ elif not isinstance(ipu_config, IPUConfig):
+ raise ValueError("ipu_config must be an IPUConfig, string, or a dictionary.")
ipu_config.inference_device_iterations = 1
# TODO: inference_replication_factor should be adaptive, especially for batching.
ipu_config.inference_replication_factor = 1
@@ -280,7 +280,7 @@ def check_model_type(self, supported_models: Union[List[str], dict]):
def pipeline(
task: str = None,
model: Optional[Any] = None,
- ipu_config: Union[str, dict] = None,
+ ipu_config: Union[IPUConfig, str, dict] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
revision: Optional[str] = None,
@@ -411,8 +411,8 @@ def new_forward(self, model_inputs, *args, **kwargs):
# Implement pipelines __del__ to clean up poplar executor
def _del(self):
# For text generation models, deallocate the internal poplar executor
- if hasattr(self.model, "poptorch_model"):
- self.model.poptorch_model.destroy()
+ if hasattr(self.model, "poptorch_decoder"):
+ self.model.poptorch_decoder.destroy()

pipeline_class.__del__ = _del

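Taken together, the pipeline changes let an IPUConfig instance flow straight through get_poplar_executor, and pipeline teardown now destroys the decoder executor (poptorch_decoder) rather than poptorch_model. A rough usage sketch, assuming the default text-generation model above is acceptable (running it needs a Poplar SDK / IPU environment):

```python
from optimum.graphcore import IPUConfig
from optimum.graphcore.pipelines import pipeline

# ipu_config may now be an IPUConfig object, in addition to a hub name or a dict.
generator = pipeline(
    "text-generation",
    ipu_config=IPUConfig(layers_per_ipu=[12], matmul_proportion=0.2),
)
print(generator("IPUs make text generation"))

# Dropping the pipeline triggers __del__, which releases the decoder's
# Poplar executor via model.poptorch_decoder.destroy().
del generator
```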
