Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing return type tensor with num_return_sequences>1. #16828

Merged
merged 2 commits into from
Apr 20, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/transformers/pipelines/text2text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def postprocess(self, model_outputs, return_type=ReturnType.TEXT, clean_up_token
records = []
for output_ids in model_outputs["output_ids"][0]:
if return_type == ReturnType.TENSORS:
record = {f"{self.return_name}_token_ids": model_outputs}
record = {f"{self.return_name}_token_ids": output_ids}
elif return_type == ReturnType.TEXT:
record = {
f"{self.return_name}_text": self.tokenizer.decode(
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_
records = []
for sequence in generated_sequence:
if return_type == ReturnType.TENSORS:
record = {"generated_token_ids": generated_sequence}
record = {"generated_token_ids": sequence}
elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
# Decode text
text = self.tokenizer.decode(
Expand Down
33 changes: 33 additions & 0 deletions tests/pipelines/test_pipelines_text2text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,39 @@ def test_small_model_pt(self):
]
self.assertEqual(outputs, target_outputs)

import torch

outputs = generator("This is a test", do_sample=True, num_return_sequences=2, return_tensors=True)
self.assertEqual(
outputs,
[
{"generated_token_ids": ANY(torch.Tensor)},
{"generated_token_ids": ANY(torch.Tensor)},
],
)
generator.tokenizer.pad_token_id = generator.model.config.eos_token_id
generator.tokenizer.pad_token = "<pad>"
outputs = generator(
["This is a test", "This is a second test"],
do_sample=True,
num_return_sequences=2,
batch_size=2,
return_tensors=True,
)
self.assertEqual(
outputs,
[
[
{"generated_token_ids": ANY(torch.Tensor)},
{"generated_token_ids": ANY(torch.Tensor)},
],
[
{"generated_token_ids": ANY(torch.Tensor)},
{"generated_token_ids": ANY(torch.Tensor)},
],
],
)

@require_tf
def test_small_model_tf(self):
generator = pipeline("text2text-generation", model="patrickvonplaten/t5-tiny-random", framework="tf")
Expand Down
31 changes: 31 additions & 0 deletions tests/pipelines/test_pipelines_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,37 @@ def test_small_model_pt(self):
],
)

outputs = text_generator("This is a test", do_sample=True, num_return_sequences=2, return_tensors=True)
self.assertEqual(
outputs,
[
{"generated_token_ids": ANY(list)},
{"generated_token_ids": ANY(list)},
],
)
text_generator.tokenizer.pad_token_id = text_generator.model.config.eos_token_id
text_generator.tokenizer.pad_token = "<pad>"
outputs = text_generator(
["This is a test", "This is a second test"],
do_sample=True,
num_return_sequences=2,
batch_size=2,
return_tensors=True,
)
self.assertEqual(
outputs,
[
[
{"generated_token_ids": ANY(list)},
{"generated_token_ids": ANY(list)},
],
[
{"generated_token_ids": ANY(list)},
{"generated_token_ids": ANY(list)},
],
],
)

@require_tf
def test_small_model_tf(self):
text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf")
Expand Down