From 35f3875b824e1dc3a6acdbda992160cd20c3ddca Mon Sep 17 00:00:00 2001
From: FrostML <380185688@qq.com>
Date: Wed, 14 Sep 2022 07:49:10 +0000
Subject: [PATCH] update t5 tests

---
 tests/transformers/t5/test_modeling.py | 56 +++++++++++---------------
 1 file changed, 23 insertions(+), 33 deletions(-)

diff --git a/tests/transformers/t5/test_modeling.py b/tests/transformers/t5/test_modeling.py
index ee14f4598f40..8ca7c882e29e 100644
--- a/tests/transformers/t5/test_modeling.py
+++ b/tests/transformers/t5/test_modeling.py
@@ -627,7 +627,7 @@ def test_small_generation(self):
 
         sequences = model.generate(input_ids,
                                    max_length=8,
-                                   decode_strategy="greedy_search")
+                                   decode_strategy="greedy_search")[0]
 
         output_str = tokenizer.batch_decode(sequences,
                                             skip_special_tokens=True)[0]
@@ -659,7 +659,7 @@ def test_small_v1_1_integration_test(self):
         model.eval()
 
         input_ids = tokenizer("Hello there", return_tensors="pd")["input_ids"]
-        labels = tokenizer("Hi I am", return_tensors="pt")["input_ids"]
+        labels = tokenizer("Hi I am", return_tensors="pd")["input_ids"]
 
         loss = model(input_ids, labels=labels)[0]
         mtf_score = -(labels.shape[-1] * loss.item())
@@ -669,9 +669,9 @@
 
     @slow
     def test_summarization(self):
-        model = self.model
+        model = self.model()
         model.eval()
-        tok = self.tokenizer
+        tok = self.tokenizer()
 
         FRANCE_ARTICLE = (  # @noqa
             "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings"
@@ -860,19 +860,10 @@
         )
 
         expected_summaries = [
-            'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
-            " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
-            " magazine says .",
-            "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
-            " preliminary examination into the situation in the occupied Palestinian territory . as members of the"
-            " court, Palestinians may be subject to counter-charges as well .",
-            "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:"
-            " the debate that has already begun since the announcement of the new framework will likely result in more"
-            " heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and"
-            " implement a rigorous inspection regime .",
-            "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two"
-            ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10'
-            " times, with nine of her marriages occurring between 1999 and 2002 .",
+            'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," one magazine says . all 150 on board were killed in the crash .',
+            'the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a preliminary examination into the situation in the occupied Palestinian territory . as members of the court, Palestinians may be subject to counter-charges as well .',
+            "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . he says the new framework would reduce Iran's low-enriched uranium stockpile and cut centrifuges . miller: if it had been, there would have been no Iranian team at the negotiating table .",
+            'prosecutors say the marriages were part of an immigration scam . barrientos pleaded not guilty to two counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .',
         ]
 
         dct = tok(
@@ -907,13 +898,13 @@
 
     @slow
     def test_translation_en_to_de(self):
-        model = self.model
-        tok = self.tokenizer
-        use_task_specific_params(model, "translation_en_to_de")
+        model = self.model()
+        model.eval()
+        tok = self.tokenizer()
 
         en_text = '"Luigi often said to me that he never wanted the brothers to end up in court", she wrote.'
         expected_translation = (
-            '"Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.'
+            '"Luigi sagte mir oft, er wollte nie, dass die Brüder am Gericht enden", schrieb sie.'
         )
 
         input_ids = tok.encode("translate English to German: " + en_text,
@@ -928,8 +919,9 @@
 
     @slow
     def test_translation_en_to_fr(self):
-        model = self.model  # t5-base
-        tok = self.tokenizer
+        model = self.model()  # t5-base
+        model.eval()
+        tok = self.tokenizer()
 
         en_text = (
             ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of'
@@ -950,22 +942,20 @@
         translation = tok.decode(output[0][0],
                                  skip_special_tokens=True,
                                  clean_up_tokenization_spaces=False)
-        new_truncated_translation = (
-            "Cette section d'images provenant de l'enregistrement infrarouge effectué par le télescope Spitzer montre "
-            "un "
-            "« portrait familial » de générations innombrables d’étoiles : les plus anciennes sont observées "
-            "sous forme "
-            "de points bleus.")
+        new_truncated_translation = [
+            "Cette section d'images d'un enregistrement infrarouge du télescope Spitzer montre un « portrait familial » d'innombrables générations d'étoiles : les étoiles les plus anciennes sont visibles sous forme de points bleus."
+        ]
 
-        self.assertEqual(translation, new_truncated_translation)
+        self.assertEqual(translation, new_truncated_translation[0])
 
     @slow
     def test_translation_en_to_ro(self):
-        model = self.model
-        tok = self.tokenizer
+        model = self.model()
+        model.eval()
+        tok = self.tokenizer()
 
         en_text = "Taco Bell said it plans to add 2,000 locations in the US by 2022."
-        expected_translation = "Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022."
+        expected_translation = 'Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022.'
 
         input_ids = tok("translate English to Romanian: " + en_text,
                         return_tensors="pd")["input_ids"]
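Context for the recurring changes in this patch: in PaddleNLP, model.generate() returns a tuple whose first element holds the generated token ids, hence the added [0] indexing, and the tokenizers take return_tensors="pd" to produce Paddle tensors (the "pt" value was a PyTorch-style leftover). A minimal sketch of the updated call pattern, assuming the t5-small checkpoint (the checkpoint name is illustrative, not taken from this patch):

    import paddle
    from paddlenlp.transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    model.eval()

    # "pd" asks the tokenizer for Paddle tensors ("pt" would be the PyTorch-style flag)
    input_ids = tokenizer("translate English to German: Hello there",
                          return_tensors="pd")["input_ids"]

    with paddle.no_grad():
        # generate() returns a tuple; element 0 holds the generated ids
        sequences = model.generate(input_ids,
                                   max_length=8,
                                   decode_strategy="greedy_search")[0]

    print(tokenizer.batch_decode(sequences, skip_special_tokens=True)[0])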