Efficient decoder text generation wrapper (#273)
* Set the default matmul_proportion in IPUConfig to 0.2 so default config will work with decoder wrapper
jimypbr authored Mar 10, 2023
1 parent f0bc341 commit 66929fd
Showing 3 changed files with 46 additions and 42 deletions.
66 changes: 35 additions & 31 deletions optimum/graphcore/generation_utils.py
@@ -46,16 +46,44 @@
 logger = logging.get_logger(__name__)


+class DecoderWrapper(nn.Module):
+    """
+    Fast wrapper for decoder part of text generation models.
+    Only returns the logits from the last generated token to reduce IO costs.
+    """
+
+    def __init__(self, pipelined_model):
+        super().__init__()
+        self.pipelined_model = pipelined_model
+
+    def forward(self, t, **model_inputs):
+        """
+        Args:
+            t : (`torch.Tensor(int)`) Tensor with single int representing the current length of the sequence being generated
+            model_inputs : Regular model_inputs passed to the wrapped model.
+        Returns:
+            The output logits at position `t` only
+        """
+        outputs = self.pipelined_model(**model_inputs)
+
+        next_token_logits = poptorch.dynamic_slice(outputs.logits, 1, t, 1, 1)
+        return type(outputs)(
+            loss=None,
+            logits=next_token_logits,
+        )
+
+
 class IPUGenerationMixin(GenerationMixin):
     def _pad_tensors_to_max_len(self, tensor: torch.Tensor, max_length: int, pad_token_id: int) -> torch.Tensor:
         return nn.functional.pad(tensor, (0, max_length - tensor.shape[1]), "constant", pad_token_id)

     def _call_generate(self, *args, **kwargs):
-        if not hasattr(self, "poptorch_model"):
-            self.poptorch_model = poptorch.inferenceModel(self.eval(), self.ipu_config.to_options(for_inference=True))
+        if not hasattr(self, "poptorch_decoder"):
+            wrapper = DecoderWrapper(self.eval())
+            self.poptorch_decoder = poptorch.inferenceModel(wrapper, self.ipu_config.to_options(for_inference=True))

         # This will trigger a compile first time it's ran
-        return self.poptorch_model(*args, **kwargs)
+        return self.poptorch_decoder(*args, **kwargs)

     def _prepare_encoder_decoder_kwargs_for_generation(
         self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
@@ -245,6 +273,7 @@ def greedy_search(

             # forward pass to get next token
             outputs = self._call_generate(
+                t=torch.tensor(cur_len - 1),
                 **model_inputs,
                 return_dict=True,
                 output_attentions=output_attentions,
@@ -255,13 +284,6 @@
             if not self.config.is_encoder_decoder:
                 model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :cur_len]

-            outputs.logits = outputs.logits[:, :cur_len, :]
-            if outputs.logits.dim() == 3:
-                outputs.logits = outputs.logits[:, :cur_len, :]
-            # If the dimension of logits is 2, then only the logits of the last non-padding token is returned, so no need to slice.
-            else:
-                next_token_logits = outputs.logits
-
             # Change: remove synced_gpu code

             next_token_logits = outputs.logits[:, -1, :]
@@ -516,6 +538,7 @@ def beam_search(
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

             outputs = self._call_generate(
+                t=torch.tensor(cur_len - 1),
                 **model_inputs,
                 return_dict=True,
                 output_attentions=output_attentions,
@@ -526,13 +549,6 @@
             if not self.config.is_encoder_decoder:
                 model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :cur_len]

-            outputs.logits = outputs.logits[:, :cur_len, :]
-            if outputs.logits.dim() == 3:
-                outputs.logits = outputs.logits[:, :cur_len, :]
-            # If the dimension of logits is 2, then only the logits of the last non-padding token is returned, so no need to slice.
-            else:
-                next_token_logits = outputs.logits
-
             # Change: remove synced_gpu code

             # Change: cast to float on cpu
@@ -820,6 +836,7 @@ def sample(

             # forward pass to get next token
             outputs = self._call_generate(
+                t=torch.tensor(cur_len - 1),
                 **model_inputs,
                 return_dict=True,
                 output_attentions=output_attentions,
@@ -830,13 +847,6 @@
             if not self.config.is_encoder_decoder:
                 model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :cur_len]

-            outputs.logits = outputs.logits[:, :cur_len, :]
-            if outputs.logits.dim() == 3:
-                outputs.logits = outputs.logits[:, :cur_len, :]
-            # If the dimension of logits is 2, then only the logits of the last non-padding token is returned, so no need to slice.
-            else:
-                next_token_logits = outputs.logits
-
             # Change: remove synced_gpu code

             # Change: cast to float on cpu
@@ -1098,6 +1108,7 @@ def beam_sample(
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

             outputs = self._call_generate(
+                t=torch.tensor(cur_len - 1),
                 **model_inputs,
                 return_dict=True,
                 output_attentions=output_attentions,
@@ -1108,13 +1119,6 @@
             if not self.config.is_encoder_decoder:
                 model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :cur_len]

-            outputs.logits = outputs.logits[:, :cur_len, :]
-            if outputs.logits.dim() == 3:
-                outputs.logits = outputs.logits[:, :cur_len, :]
-            # If the dimension of logits is 2, then only the logits of the last non-padding token is returned, so no need to slice.
-            else:
-                next_token_logits = outputs.logits
-
             # Change: remove synced_gpu code

             # Change: cast to float on cpu
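The net effect of the `DecoderWrapper` change above is that only the logits for the current position leave the IPU on each generation step, rather than the full `(batch, sequence, vocab)` tensor. The sketch below is illustrative only and not part of the commit: it reproduces the effect of `poptorch.dynamic_slice(outputs.logits, 1, t, 1, 1)` with plain PyTorch and made-up shapes, to show what the wrapped decoder now returns.

# Illustrative sketch (not from the commit): a plain-PyTorch stand-in for the
# slice DecoderWrapper performs on device with poptorch.dynamic_slice.
import torch


def logits_at_step(logits: torch.Tensor, t: int) -> torch.Tensor:
    # Keep only position t along the sequence dimension (dim 1):
    # (batch, seq_len, vocab) -> (batch, 1, vocab), like dynamic_slice(logits, 1, t, 1, 1).
    return logits.narrow(1, t, 1)


full_logits = torch.randn(2, 128, 50257)  # example (batch, seq_len, vocab)
step_logits = logits_at_step(full_logits, t=127)
print(step_logits.shape)  # torch.Size([2, 1, 50257])

Because the wrapper rebuilds the output object with `type(outputs)(loss=None, logits=next_token_logits)`, the greedy/beam/sample loops above only need to pass `t=torch.tensor(cur_len - 1)` and can keep reading `outputs.logits` as before.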
2 changes: 1 addition & 1 deletion optimum/graphcore/ipu_configuration.py
@@ -138,7 +138,7 @@ def __init__(self, **kwargs):
                 'The "sharded_execution_for_inference" parameter is deprecated, sharded execution is always used during inference'
             )

-        self.matmul_proportion = kwargs.pop("matmul_proportion", 0.6)
+        self.matmul_proportion = kwargs.pop("matmul_proportion", 0.2)

         if "enable_half_first_order_momentum" in kwargs:
             warnings.warn('The "enable_half_first_order_momentum" parameter is deprecated')
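Per the commit message, lowering the default `matmul_proportion` to 0.2 is what lets the stock configurations work with the decoder wrapper; a model that wants the previous behaviour can still set the value explicitly. A hedged sketch, assuming `IPUConfig` is importable from `optimum.graphcore`:

# Hedged example: overriding the new 0.2 default, e.g. restoring the old 0.6.
from optimum.graphcore import IPUConfig

ipu_config = IPUConfig(layers_per_ipu=[12], matmul_proportion=0.6)
print(ipu_config.matmul_proportion)  # 0.6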
20 changes: 10 additions & 10 deletions optimum/graphcore/pipelines/__init__.py
@@ -138,7 +138,7 @@ class IncompatibleIPUConfigError(Exception):
         "class": (AutoModelForCausalLM,),
         "default": {
             "model": ("gpt2", "e7da7f2"),
-            "ipu_config": "Graphcore/gpt2-small-ipu",
+            "ipu_config": IPUConfig(layers_per_ipu=[12], matmul_proportion=0.2),
             "max_length": 50,
         },
         "type": "text",
@@ -148,7 +148,7 @@ class IncompatibleIPUConfigError(Exception):
         "class": (AutoModelForSeq2SeqLM,),
         "default": {
             "model": ("ainize/bart-base-cnn", "b90bc9a"),
-            "ipu_config": "Graphcore/bart-base-ipu",
+            "ipu_config": IPUConfig(layers_per_ipu=[12], matmul_proportion=0.2),
             "max_input_length": 50,
             "max_length": 20,
             "truncation": "only_first",
@@ -161,7 +161,7 @@ class IncompatibleIPUConfigError(Exception):
         "class": (AutoModelForSeq2SeqLM,),
         "default": {
             "model": ("t5-small", "9507060"),
-            "ipu_config": "Graphcore/t5-small-ipu",
+            "ipu_config": IPUConfig(layers_per_ipu=[12], matmul_proportion=0.2),
             "max_length": 50,
             "max_input_length": 45,
             "truncation": "only_first",
@@ -173,7 +173,7 @@ class IncompatibleIPUConfigError(Exception):
         "class": (AutoModelForSeq2SeqLM,),
         "default": {
             "model": ("t5-small", "9507060"),
-            "ipu_config": "Graphcore/t5-small-ipu",
+            "ipu_config": IPUConfig(layers_per_ipu=[12], matmul_proportion=0.2),
             "max_length": 50,
             "max_input_length": 50,
             "truncation": "only_first",
@@ -210,7 +210,7 @@ def list_tasks() -> List[str]:

 def get_poplar_executor(
     model: PreTrainedModel,
-    ipu_config: Union[str, dict] = None,
+    ipu_config: Union[IPUConfig, str, dict] = None,
     fp16: bool = True,
 ) -> PreTrainedModel:
     ipu_config_arg = ipu_config
@@ -219,8 +219,8 @@ def get_poplar_executor(
         ipu_config = IPUConfig.from_pretrained(ipu_config)
     elif isinstance(ipu_config, dict):
         ipu_config = IPUConfig.from_dict(ipu_config)
-    else:
-        raise ValueError("ipu_config must be a string or a dictionary.")
+    elif not isinstance(ipu_config, IPUConfig):
+        raise ValueError("ipu_config must be an IPUConfig, string, or a dictionary.")
     ipu_config.inference_device_iterations = 1
     # TODO: inference_replication_factor should be adaptive, especially for batching.
     ipu_config.inference_replication_factor = 1
@@ -280,7 +280,7 @@ def check_model_type(self, supported_models: Union[List[str], dict]):
 def pipeline(
     task: str = None,
     model: Optional[Any] = None,
-    ipu_config: Union[str, dict] = None,
+    ipu_config: Union[IPUConfig, str, dict] = None,
     tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
     feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
     revision: Optional[str] = None,
@@ -411,8 +411,8 @@ def new_forward(self, model_inputs, *args, **kwargs):
     # Implement pipelines __del__ to clean up poplar exector
     def _del(self):
         # For text generation models, deallocate the internal poplar executor
-        if hasattr(self.model, "poptorch_model"):
-            self.model.poptorch_model.destroy()
+        if hasattr(self.model, "poptorch_decoder"):
+            self.model.poptorch_decoder.destroy()

     pipeline_class.__del__ = _del

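With `ipu_config` now typed as `Union[IPUConfig, str, dict]`, a pipeline can be handed an `IPUConfig` instance directly instead of a Hub name or a plain dict. A hedged usage sketch follows: the import paths are assumptions, the `"text-generation"` task name and its `gpt2` default come from the table above, and actually running it requires IPU hardware plus the PopTorch SDK.

# Hedged usage sketch: passing an IPUConfig object straight to pipeline(),
# mirroring the new text-generation default above.
from optimum.graphcore import IPUConfig
from optimum.graphcore.pipelines import pipeline

text_generator = pipeline(
    "text-generation",
    ipu_config=IPUConfig(layers_per_ipu=[12], matmul_proportion=0.2),
)
print(text_generator("IPUs are")[0]["generated_text"])

When the pipeline object is garbage-collected, the `_del` hook above destroys the cached `poptorch_decoder` executor (previously `poptorch_model`), releasing the IPU.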
