Efficient decoder text generation wrapper #273

Merged
merged 10 commits on Mar 10, 2023
66 changes: 35 additions & 31 deletions optimum/graphcore/generation_utils.py
@@ -46,16 +46,44 @@
logger = logging.get_logger(__name__)


+class DecoderWrapper(nn.Module):
+    """
+    Fast wrapper for the decoder part of text generation models.
+    Returns only the logits of the last generated token to reduce IO costs.
+    """
+
+    def __init__(self, pipelined_model):
+        super().__init__()
+        self.pipelined_model = pipelined_model
+
+    def forward(self, t, **model_inputs):
+        """
+        Args:
+            t : (`torch.Tensor(int)`) Tensor holding a single int, the index of the position whose logits are returned (callers pass cur_len - 1).
+            model_inputs : Regular model_inputs passed to the wrapped model.
+        Returns:
+            The output logits at position `t` only.
+        """
+        outputs = self.pipelined_model(**model_inputs)
+
+        next_token_logits = poptorch.dynamic_slice(outputs.logits, 1, t, 1, 1)
+        return type(outputs)(
+            loss=None,
+            logits=next_token_logits,
+        )


class IPUGenerationMixin(GenerationMixin):
    def _pad_tensors_to_max_len(self, tensor: torch.Tensor, max_length: int, pad_token_id: int) -> torch.Tensor:
        return nn.functional.pad(tensor, (0, max_length - tensor.shape[1]), "constant", pad_token_id)

    def _call_generate(self, *args, **kwargs):
-        if not hasattr(self, "poptorch_model"):
-            self.poptorch_model = poptorch.inferenceModel(self.eval(), self.ipu_config.to_options(for_inference=True))
+        if not hasattr(self, "poptorch_decoder"):
+            wrapper = DecoderWrapper(self.eval())
+            self.poptorch_decoder = poptorch.inferenceModel(wrapper, self.ipu_config.to_options(for_inference=True))

        # This will trigger a compile the first time it is run
-        return self.poptorch_model(*args, **kwargs)
+        return self.poptorch_decoder(*args, **kwargs)

    def _prepare_encoder_decoder_kwargs_for_generation(
        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
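
For readers unfamiliar with poptorch.dynamic_slice, the following host-side sketch (illustrative only, plain torch, made-up sizes) shows the selection the wrapper performs on device: only the logits at position t, a (batch, 1, vocab) tensor, are streamed back to the host instead of the full (batch, seq_len, vocab) logits.

import torch

batch, seq_len, vocab = 2, 8, 50257  # hypothetical sizes
logits = torch.randn(batch, seq_len, vocab)
t = torch.tensor(5)  # index of the position currently being generated

# Host-side equivalent of poptorch.dynamic_slice(logits, 1, t, 1, 1):
# take a slice of size 1 along dim 1, starting at t.
next_token_logits = logits.narrow(1, int(t), 1)
assert next_token_logits.shape == (batch, 1, vocab)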
@@ -245,6 +273,7 @@ def greedy_search(

            # forward pass to get next token
            outputs = self._call_generate(
+                t=torch.tensor(cur_len - 1),
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
@@ -255,13 +284,6 @@
            if not self.config.is_encoder_decoder:
                model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :cur_len]

-            outputs.logits = outputs.logits[:, :cur_len, :]
-            if outputs.logits.dim() == 3:
-                outputs.logits = outputs.logits[:, :cur_len, :]
-            # If the dimension of logits is 2, then only the logits of the last non-padding token is returned, so no need to slice.
-            else:
-                next_token_logits = outputs.logits
-
            # Change: remove synced_gpu code

            next_token_logits = outputs.logits[:, -1, :]
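
The same change is applied in beam_search, sample, and beam_sample below, mirroring greedy_search above: pass t = cur_len - 1 into the compiled decoder and drop the host-side :cur_len slicing. As a rough sketch (model, model_inputs, and cur_len are stand-ins for the names used in the surrounding methods, not an exact excerpt):

import torch

def next_token_logits_for_step(model, model_inputs, cur_len):
    # The compiled DecoderWrapper returns logits of shape (batch, 1, vocab),
    # already sliced to position cur_len - 1 on device.
    outputs = model._call_generate(
        t=torch.tensor(cur_len - 1),
        **model_inputs,
        return_dict=True,
    )
    # Indexing the last (and only) position therefore replaces the removed
    # outputs.logits[:, :cur_len, :] slicing previously done on the host.
    return outputs.logits[:, -1, :]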
@@ -516,6 +538,7 @@ def beam_search(
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self._call_generate(
+                t=torch.tensor(cur_len - 1),
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
@@ -526,13 +549,6 @@
            if not self.config.is_encoder_decoder:
                model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :cur_len]

-            outputs.logits = outputs.logits[:, :cur_len, :]
-            if outputs.logits.dim() == 3:
-                outputs.logits = outputs.logits[:, :cur_len, :]
-            # If the dimension of logits is 2, then only the logits of the last non-padding token is returned, so no need to slice.
-            else:
-                next_token_logits = outputs.logits
-
            # Change: remove synced_gpu code

            # Change: cast to float on cpu
@@ -820,6 +836,7 @@ def sample(

            # forward pass to get next token
            outputs = self._call_generate(
+                t=torch.tensor(cur_len - 1),
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
@@ -830,13 +847,6 @@
            if not self.config.is_encoder_decoder:
                model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :cur_len]

-            outputs.logits = outputs.logits[:, :cur_len, :]
-            if outputs.logits.dim() == 3:
-                outputs.logits = outputs.logits[:, :cur_len, :]
-            # If the dimension of logits is 2, then only the logits of the last non-padding token is returned, so no need to slice.
-            else:
-                next_token_logits = outputs.logits
-
            # Change: remove synced_gpu code

            # Change: cast to float on cpu
@@ -1098,6 +1108,7 @@ def beam_sample(
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self._call_generate(
+                t=torch.tensor(cur_len - 1),
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
@@ -1108,13 +1119,6 @@
            if not self.config.is_encoder_decoder:
                model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :cur_len]

-            outputs.logits = outputs.logits[:, :cur_len, :]
-            if outputs.logits.dim() == 3:
-                outputs.logits = outputs.logits[:, :cur_len, :]
-            # If the dimension of logits is 2, then only the logits of the last non-padding token is returned, so no need to slice.
-            else:
-                next_token_logits = outputs.logits
-
            # Change: remove synced_gpu code

            # Change: cast to float on cpu
2 changes: 1 addition & 1 deletion optimum/graphcore/ipu_configuration.py
@@ -138,7 +138,7 @@ def __init__(self, **kwargs):
'The "sharded_execution_for_inference" parameter is deprecated, sharded execution is always used during inference'
)

self.matmul_proportion = kwargs.pop("matmul_proportion", 0.6)
self.matmul_proportion = kwargs.pop("matmul_proportion", 0.2)

if "enable_half_first_order_momentum" in kwargs:
warnings.warn('The "enable_half_first_order_momentum" parameter is deprecated')
22 changes: 12 additions & 10 deletions optimum/graphcore/pipelines/__init__.py
@@ -138,7 +138,7 @@ class IncompatibleIPUConfigError(Exception):
"class": (AutoModelForCausalLM,),
"default": {
"model": ("gpt2", "e7da7f2"),
"ipu_config": "Graphcore/gpt2-small-ipu",
"ipu_config": IPUConfig(layers_per_ipu=[12], matmul_proportion=0.2),
"max_length": 50,
},
"type": "text",
@@ -148,7 +148,7 @@ class IncompatibleIPUConfigError(Exception):
"class": (AutoModelForSeq2SeqLM,),
"default": {
"model": ("ainize/bart-base-cnn", "b90bc9a"),
"ipu_config": "Graphcore/bart-base-ipu",
"ipu_config": IPUConfig(layers_per_ipu=[12], matmul_proportion=0.2),
"max_input_length": 50,
"max_length": 20,
"truncation": "only_first",
@@ -161,7 +161,7 @@ class IncompatibleIPUConfigError(Exception):
"class": (AutoModelForSeq2SeqLM,),
"default": {
"model": ("t5-small", "9507060"),
"ipu_config": "Graphcore/t5-small-ipu",
"ipu_config": IPUConfig(layers_per_ipu=[12], matmul_proportion=0.2),
"max_length": 50,
"max_input_length": 45,
"truncation": "only_first",
@@ -173,7 +173,7 @@ class IncompatibleIPUConfigError(Exception):
"class": (AutoModelForSeq2SeqLM,),
"default": {
"model": ("t5-small", "9507060"),
"ipu_config": "Graphcore/t5-small-ipu",
"ipu_config": IPUConfig(layers_per_ipu=[12], matmul_proportion=0.2),
"max_length": 50,
"max_input_length": 50,
"truncation": "only_first",
@@ -210,7 +210,7 @@ def list_tasks() -> List[str]:

def get_poplar_executor(
    model: PreTrainedModel,
-    ipu_config: Union[str, dict] = None,
+    ipu_config: Union[IPUConfig, str, dict] = None,
    fp16: bool = True,
) -> PreTrainedModel:
    ipu_config_arg = ipu_config
@@ -219,8 +219,8 @@ def get_poplar_executor(
        ipu_config = IPUConfig.from_pretrained(ipu_config)
    elif isinstance(ipu_config, dict):
        ipu_config = IPUConfig.from_dict(ipu_config)
-    else:
-        raise ValueError("ipu_config must be a string or a dictionary.")
+    elif not isinstance(ipu_config, IPUConfig):
+        raise ValueError("ipu_config must be an IPUConfig, string, or a dictionary.")
    ipu_config.inference_device_iterations = 1
    # TODO: inference_replication_factor should be adaptive, especially for batching.
    ipu_config.inference_replication_factor = 1
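
With this change an IPUConfig instance is accepted directly, alongside a Hub name (str) or a dict. A minimal usage sketch, assuming the top-level optimum.graphcore imports and the gpt2 default model listed above:

from optimum.graphcore import IPUConfig, pipeline

generator = pipeline(
    "text-generation",
    ipu_config=IPUConfig(layers_per_ipu=[12], matmul_proportion=0.2),
)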
@@ -280,7 +280,7 @@ def check_model_type(self, supported_models: Union[List[str], dict]):
def pipeline(
    task: str = None,
    model: Optional[Any] = None,
-    ipu_config: Union[str, dict] = None,
+    ipu_config: Union[IPUConfig, str, dict] = None,
    tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
    feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
    revision: Optional[str] = None,
@@ -411,8 +411,10 @@ def new_forward(self, model_inputs, *args, **kwargs):
    # Implement pipelines __del__ to clean up the poplar executor
    def _del(self):
        # For text generation models, deallocate the internal poplar executor
-        if hasattr(self.model, "poptorch_model"):
-            self.model.poptorch_model.destroy()
+        if hasattr(self.model, "poptorch_decoder"):
+            self.model.poptorch_decoder.destroy()
+        if hasattr(self.model, "poptorch_encoder"):
+            self.model.poptorch_encoder.destroy()

    pipeline_class.__del__ = _del
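
Because the compiled executors hold on to IPUs, the _del hook above releases them when the pipeline object is deleted. A hedged usage sketch, with the same assumed imports as the earlier example:

from optimum.graphcore import IPUConfig, pipeline

generator = pipeline("text-generation", ipu_config=IPUConfig(layers_per_ipu=[12]))
print(generator("Hello, my name is", max_length=20))
del generator  # triggers _del: poptorch_decoder (and poptorch_encoder, if present) are destroyed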
