Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine T5 tests #3266

Merged
merged 1 commit into from
Sep 14, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 23 additions & 33 deletions tests/transformers/t5/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ def test_small_generation(self):

sequences = model.generate(input_ids,
max_length=8,
decode_strategy="greedy_search")
decode_strategy="greedy_search")[0]

output_str = tokenizer.batch_decode(sequences,
skip_special_tokens=True)[0]
Expand Down Expand Up @@ -659,7 +659,7 @@ def test_small_v1_1_integration_test(self):
model.eval()

input_ids = tokenizer("Hello there", return_tensors="pd")["input_ids"]
labels = tokenizer("Hi I am", return_tensors="pt")["input_ids"]
labels = tokenizer("Hi I am", return_tensors="pd")["input_ids"]

loss = model(input_ids, labels=labels)[0]
mtf_score = -(labels.shape[-1] * loss.item())
Expand All @@ -669,9 +669,9 @@ def test_small_v1_1_integration_test(self):

@slow
def test_summarization(self):
model = self.model
model = self.model()
model.eval()
tok = self.tokenizer
tok = self.tokenizer()

FRANCE_ARTICLE = ( # @noqa
"Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings"
Expand Down Expand Up @@ -860,19 +860,10 @@ def test_summarization(self):
)

expected_summaries = [
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
" cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
" magazine says .",
"the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
" preliminary examination into the situation in the occupied Palestinian territory . as members of the"
" court, Palestinians may be subject to counter-charges as well .",
"the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:"
" the debate that has already begun since the announcement of the new framework will likely result in more"
" heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and"
" implement a rigorous inspection regime .",
"prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two"
' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10'
" times, with nine of her marriages occurring between 1999 and 2002 .",
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," one magazine says . all 150 on board were killed in the crash .',
'the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a preliminary examination into the situation in the occupied Palestinian territory . as members of the court, Palestinians may be subject to counter-charges as well .',
"the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . he says the new framework would reduce Iran's low-enriched uranium stockpile and cut centrifuges . miller: if it had been, there would have been no Iranian team at the negotiating table .",
'prosecutors say the marriages were part of an immigration scam . barrientos pleaded not guilty to two counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .',
]

dct = tok(
Expand Down Expand Up @@ -907,13 +898,13 @@ def test_summarization(self):

@slow
def test_translation_en_to_de(self):
model = self.model
tok = self.tokenizer
use_task_specific_params(model, "translation_en_to_de")
model = self.model()
model.eval()
tok = self.tokenizer()

en_text = '"Luigi often said to me that he never wanted the brothers to end up in court", she wrote.'
expected_translation = (
'"Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.'
'"Luigi sagte mir oft, er wollte nie, dass die Brüder am Gericht enden", schrieb sie.'
)

input_ids = tok.encode("translate English to German: " + en_text,
Expand All @@ -928,8 +919,9 @@ def test_translation_en_to_de(self):

@slow
def test_translation_en_to_fr(self):
model = self.model # t5-base
tok = self.tokenizer
model = self.model() # t5-base
model.eval()
tok = self.tokenizer()

en_text = (
' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of'
Expand All @@ -950,22 +942,20 @@ def test_translation_en_to_fr(self):
translation = tok.decode(output[0][0],
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
new_truncated_translation = (
"Cette section d'images provenant de l'enregistrement infrarouge effectué par le télescope Spitzer montre "
"un "
"« portrait familial » de générations innombrables d’étoiles : les plus anciennes sont observées "
"sous forme "
"de points bleus.")
new_truncated_translation = [
"Cette section d'images d'un enregistrement infrarouge du télescope Spitzer montre un « portrait familial » d'innombrables générations d'étoiles : les étoiles les plus anciennes sont visibles sous forme de points bleus."
]

self.assertEqual(translation, new_truncated_translation)
self.assertEqual(translation, new_truncated_translation[0])

@slow
def test_translation_en_to_ro(self):
model = self.model
tok = self.tokenizer
model = self.model()
model.eval()
tok = self.tokenizer()

en_text = "Taco Bell said it plans to add 2,000 locations in the US by 2022."
expected_translation = "Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022."
expected_translation = 'Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022.'

input_ids = tok("translate English to Romanian: " + en_text,
return_tensors="pd")["input_ids"]
Expand Down