Enable generation tests #407

Merged 10 commits on Jun 8, 2023
26 changes (22 additions, 4 deletions) in optimum/graphcore/generation/utils.py
@@ -283,6 +283,14 @@ def detachFromDevice(self):
         if hasattr(self, "poptorch_decoder"):
             self.poptorch_decoder.detachFromDevice()
 
+    def destroy(self):
+        if hasattr(self, "poptorch_encoder"):
+            self.poptorch_encoder.destroy()
+            delattr(self, "poptorch_encoder")
+        if hasattr(self, "poptorch_decoder"):
+            self.poptorch_decoder.destroy()
+            delattr(self, "poptorch_decoder")
+
     def _get_generation_step_tensor(self, generation_step, ascending=False):
         # Returns a 1 dimensional tensor of the form [device_iterations * replication factor]
         # with all elements equal to generation_step.
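
The new destroy() goes a step further than detachFromDevice(): it disposes of the compiled PopTorch executables and removes the cached attributes, so the model can be recompiled from scratch. A minimal sketch of how a test might use it between cases (the helper below is hypothetical, not part of this PR):

def run_generation_case(model, input_ids):
    # `model` is assumed to be an IPU generation model exposing both
    # the transformers generate() API and the new destroy() method.
    try:
        return model.generate(input_ids, max_length=32)
    finally:
        # Tear down the compiled executables so the next case can
        # recompile the model under a different configuration.
        model.destroy()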
@@ -433,6 +441,9 @@ def greedy_search(
         max_length = stopping_criteria.max_length
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         output_scores = output_scores if output_scores is not None else self.config.output_scores
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
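
This mirrors the normalization upstream transformers performs: a bare int becomes a one-element list, which is then materialized as a tensor on the input device so the EOS check below can handle one or many ids uniformly. A standalone sketch of the same logic (the helper name is illustrative):

import torch

def normalize_eos(eos_token_id, device):
    # Accept None, an int, or a list of ints; return a 1-D tensor or None.
    if eos_token_id is None:
        return None
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    return torch.tensor(eos_token_id).to(device)

print(normalize_eos(2, "cpu"))           # tensor([2])
print(normalize_eos([2, 50256], "cpu"))  # tensor([2, 50256])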
@@ -547,8 +558,10 @@ def greedy_search(
             cur_len = cur_len + 1
 
             # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id is not None:
-                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
+            if eos_token_id_tensor is not None:
+                unfinished_sequences = unfinished_sequences.mul(
+                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                )
 
             # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
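
The tile/ne/prod expression replaces the single-id comparison with one that works for any number of EOS ids: tiling repeats the new tokens once per EOS id, ne compares each copy against one id, and the product over the id dimension is 1 exactly when the token matched none of them. A small worked example:

import torch

next_tokens = torch.tensor([5, 2, 7, 50256])  # one new token per sequence
eos_ids = torch.tensor([2, 50256])            # two EOS candidates

not_eos = (
    next_tokens.tile(eos_ids.shape[0], 1)  # shape (num_eos, batch)
    .ne(eos_ids.unsqueeze(1))              # row i: next_tokens != eos_ids[i]
    .prod(dim=0)                           # 1 iff the token matched no EOS id
)
unfinished_sequences = torch.ones(4, dtype=torch.long).mul(not_eos)
print(unfinished_sequences)  # tensor([1, 0, 1, 0]); sequences 1 and 3 finished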
@@ -1030,6 +1043,9 @@ def sample(
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         output_scores = output_scores if output_scores is not None else self.config.output_scores
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
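
sample() differs from greedy_search() mainly in the logits_warper argument, which defaults to an empty LogitsProcessorList. A sketch of the kind of warper chain a caller might pass (the temperature and top-k values are illustrative):

import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper

warpers = LogitsProcessorList([TemperatureLogitsWarper(0.7), TopKLogitsWarper(top_k=50)])
fake_scores = torch.randn(1, 100)                    # next-token logits for one sequence
warped = warpers(torch.tensor([[0]]), fake_scores)   # (input_ids, scores) -> warped scores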
@@ -1139,8 +1155,10 @@ def sample(
             cur_len = cur_len + 1
 
             # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id is not None:
-                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
+            if eos_token_id_tensor is not None:
+                unfinished_sequences = unfinished_sequences.mul(
+                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                )
 
             # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
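
With greedy_search() and sample() both updated, generate() accepts either form of eos_token_id. A hypothetical call (model construction omitted):

import torch

# `model` is assumed to expose the usual transformers generate() API.
input_ids = torch.tensor([[0, 5, 7]])
out_single = model.generate(input_ids, eos_token_id=50256)      # a single id
out_multi = model.generate(input_ids, eos_token_id=[50256, 2])  # several ids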