Skip to content

Commit

Permalink
Disallow num_return_sequences > 1 #9
Browse files Browse the repository at this point in the history
  • Loading branch information
jvamvas committed May 2, 2024
1 parent 8bf38e2 commit 5606ec1
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/mbr/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ def generate(
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())

if generation_config.num_return_sequences > 1:
raise ValueError("MBR decoding does not support `num_return_sequences` > 1.")

if references_config is not None:
references_config.validate()

Expand Down
32 changes: 30 additions & 2 deletions tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def test_generation_config(self):
generation_config=generation_config,
mbr_config=mbr_config,
tokenizer=self.tokenizer,
do_sample=True,
)
output = self.tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(1, len(output))
Expand Down Expand Up @@ -179,12 +178,41 @@ def test_references_config(self):
references_config=references_config,
mbr_config=mbr_config,
tokenizer=self.tokenizer,
do_sample=True,
)
output = self.tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(1, len(output))
self.assertTrue(output[0].startswith("Hello, my name is"))

def test_num_return_sequences(self):
    """
    generation_config.num_return_sequences > 1 is not supported for MBR; should raise an error.

    Two ways of requesting several return sequences are exercised: via an explicit
    `generation_config` object, and via a `num_return_sequences` keyword argument
    passed directly to `generate()`.
    """
    mbr_config = MBRConfig(
        num_samples=5,
    )
    generation_config = GenerationConfig.from_pretrained(
        "distilgpt2",
        do_sample=True,
        num_return_sequences=2,
    )
    input_sentences = [
        "Hello, my name is",
    ]
    encoding = self.tokenizer(input_sentences, return_tensors="pt")

    # Require the error message to name the offending option. A bare
    # assertRaises(ValueError) would also pass if generate() raised a ValueError
    # for some unrelated reason, hiding a regression in the actual check.
    with self.assertRaisesRegex(ValueError, "num_return_sequences"):
        self.model.generate(
            **encoding,
            generation_config=generation_config,
            mbr_config=mbr_config,
            tokenizer=self.tokenizer,
        )
    with self.assertRaisesRegex(ValueError, "num_return_sequences"):
        self.model.generate(
            **encoding,
            mbr_config=mbr_config,
            tokenizer=self.tokenizer,
            num_return_sequences=2,
        )


@unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Requires extra dependencies")
class EncoderDecoderTestCase(TestCase):
Expand Down

0 comments on commit 5606ec1

Please sign in to comment.