diff --git a/paddlenlp/transformers/roformerv2/modeling.py b/paddlenlp/transformers/roformerv2/modeling.py
index 857afe2802e1..f41f3354e548 100644
--- a/paddlenlp/transformers/roformerv2/modeling.py
+++ b/paddlenlp/transformers/roformerv2/modeling.py
@@ -520,6 +520,14 @@ def forward(self,
 
         return outputs
 
+    def get_input_embeddings(self) -> nn.Embedding:
+        """Return the word-embedding layer stored at ``self.embeddings.word_embeddings``."""
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, embedding: nn.Embedding):
+        """Replace ``self.embeddings.word_embeddings`` with ``embedding``."""
+        self.embeddings.word_embeddings = embedding
+
 
 class RoFormerv2ForQuestionAnswering(RoFormerv2PretrainedModel):
     """
diff --git a/paddlenlp/transformers/roformerv2/tokenizer.py b/paddlenlp/transformers/roformerv2/tokenizer.py
index a62153bc9040..a96ce67cfe5a 100644
--- a/paddlenlp/transformers/roformerv2/tokenizer.py
+++ b/paddlenlp/transformers/roformerv2/tokenizer.py
@@ -100,6 +100,14 @@ class RoFormerv2Tokenizer(PretrainedTokenizer):
             "do_lower_case": True
         },
     }
+
+    # TODO(wj-Mcat): to be confirmed
+    # NOTE: this must stay the only ``max_model_input_sizes`` assignment in the class body;
+    # a later assignment to the same name would rebind it and silently discard this mapping.
+    max_model_input_sizes = {
+        "roformer_v2_chinese_char_small": 1024,
+        "roformer_v2_chinese_char_base": 1024,
+        "roformer_v2_chinese_char_large": 1024,
+    }
     padding_side = "right"
-    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
 
diff --git a/tests/transformers/roformerv2/__init__.py b/tests/transformers/roformerv2/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/transformers/roformerv2/test_modeling.py b/tests/transformers/roformerv2/test_modeling.py
new file mode 100644
index 000000000000..3d043d0c0794
--- /dev/null
+++ b/tests/transformers/roformerv2/test_modeling.py
@@ -0,0 +1,348 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from typing import Optional, Tuple
+from dataclasses import dataclass, fields, Field
+
+import paddle
+from paddlenlp.transformers import (
+    RoFormerv2Model,
+    RoFormerv2ForMaskedLM,
+    RoFormerv2PretrainedModel,
+    RoFormerv2ForSequenceClassification,
+    RoFormerv2ForTokenClassification,
+    RoFormerv2ForQuestionAnswering,
+    RoFormerv2ForMultipleChoice,
+)
+
+from ..test_modeling_common import ids_tensor, floats_tensor, random_attention_mask, ModelTesterMixin
+from ...testing_utils import slow
+
+
+@dataclass
+class RoFormerv2ModelTestModelConfig:
+    """RoFormerv2Model constructor arguments used by the tests; meant to stay consistent with the pretrained_init_configuration sub-fields.
+    """
+    vocab_size: int = 200
+    hidden_size: int = 36
+    num_hidden_layers: int = 6
+    num_attention_heads: int = 6
+    intermediate_size: int = 20
+    hidden_act: str = "relu"
+    hidden_dropout_prob: float = 0.1
+    attention_probs_dropout_prob: float = 0.1
+    max_position_embeddings: int = 20
+    type_vocab_size: int = 2
+    pad_token_id: int = 0
+    rotary_value: bool = False
+    use_bias: bool = False
+
+    @property
+    def model_kwargs(self) -> dict:
+        """Return the RoFormerv2ModelTestModelConfig fields of this instance as a kwargs dict for building the model."""
+        model_config_fields: Tuple[Field,
+                                   ...] = fields(RoFormerv2ModelTestModelConfig)  # base-class fields only; fields added by subclasses are left out
+        return {
+            field.name: getattr(self, field.name)
+            for field in model_config_fields
+        }
+
+
+@dataclass
+class RoFormerv2ModelTestConfig(RoFormerv2ModelTestModelConfig):
+    """Test-run settings (batch shape, optional inputs, head sizes) layered on top of the model config."""
+    batch_size: int = 2
+    seq_length: int = 7
+    is_training: bool = False
+    use_input_mask: bool = False
+    use_token_type_ids: bool = True
+
+    # head sizes: num_classes for the classification heads, num_choices for multiple choice
+    num_classes: int = 3
+    num_choices: int = 3
+
+
+class RoFormerv2ModelTester:
+
+    def __init__(
+        self,
+        parent,
+        config: Optional[RoFormerv2ModelTestConfig] = None,
+    ):
+        self.parent = parent  # the TestCase whose assert* methods are used for checks
+        self.config: RoFormerv2ModelTestConfig = config or RoFormerv2ModelTestConfig(
+        )
+
+        self.is_training = self.config.is_training
+        self.num_classes = self.config.num_classes
+        self.num_choices = self.config.num_choices
+
+    def prepare_config_and_inputs(self):
+        config = self.config
+        input_ids = ids_tensor([config.batch_size, config.seq_length],
+                               config.vocab_size)
+
+        input_mask = None
+        if self.config.use_input_mask:
+            input_mask = random_attention_mask(
+                [config.batch_size, config.seq_length])
+
+        token_type_ids = None
+        if self.config.use_token_type_ids:
+            token_type_ids = ids_tensor([config.batch_size, config.seq_length],
+                                        config.type_vocab_size)
+
+        config = self.get_config()  # rebind: from here on `config` is the model kwargs dict
+        return config, input_ids, token_type_ids, input_mask
+
+    def get_config(self) -> dict:
+        return self.config.model_kwargs
+
+    def create_and_check_model(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+    ):
+        model = RoFormerv2Model(**config)
+        model.eval()
+        result = model(input_ids,
+                       attention_mask=input_mask,
+                       token_type_ids=token_type_ids,
+                       output_hidden_states=True)
+        result = model(input_ids,
+                       token_type_ids=token_type_ids,
+                       output_hidden_states=True)
+        result = model(input_ids, output_hidden_states=True)  # only this last result is asserted on below
+        self.parent.assertEqual(result[0].shape, [
+            self.config.batch_size, self.config.seq_length,
+            self.config.hidden_size
+        ])
+        self.parent.assertEqual(result[1].shape, [
+            self.config.batch_size, self.config.seq_length,
+            self.config.hidden_size
+        ])
+
+    def create_and_check_for_multiple_choice(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+    ):
+        model = RoFormerv2ForMultipleChoice(RoFormerv2Model(**config),
+                                            num_choices=self.config.num_choices)
+        model.eval()
+        multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(
+            [-1, self.config.num_choices, -1])  # insert an axis at dim 1 and expand it to num_choices
+
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.unsqueeze(1).expand(
+                [-1, self.config.num_choices, -1])
+
+        if input_mask is not None:
+            input_mask = input_mask.unsqueeze(1).expand(
+                [-1, self.config.num_choices, -1])
+
+        result = model(
+            multiple_choice_inputs_ids,
+            attention_mask=input_mask,
+            token_type_ids=token_type_ids,
+        )
+        self.parent.assertEqual(
+            result.shape, [self.config.batch_size, self.config.num_choices])
+
+    def create_and_check_for_masked_lm(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+    ):
+        model = RoFormerv2ForMaskedLM(RoFormerv2Model(**config))
+        model.eval()
+        result = model(input_ids,
+                       attention_mask=input_mask,
+                       token_type_ids=token_type_ids)
+        self.parent.assertEqual(result.shape, [
+            self.config.batch_size, self.config.seq_length,
+            self.config.vocab_size
+        ])
+
+    def create_and_check_for_sequence_classification(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+    ):
+        model = RoFormerv2ForSequenceClassification(
+            RoFormerv2Model(**config), num_classes=self.config.num_classes)
+        model.eval()
+        result = model(
+            input_ids,
+            attention_mask=input_mask,
+            token_type_ids=token_type_ids,
+        )
+        self.parent.assertEqual(
+            result.shape, [self.config.batch_size, self.config.num_classes])
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        (
+            config,
+            input_ids,
+            token_type_ids,
+            input_mask,
+        ) = config_and_inputs
+        inputs_dict = {
+            "input_ids": input_ids,
+            "token_type_ids": token_type_ids,
+            "attention_mask": input_mask
+        }
+        return config, inputs_dict
+
+    def create_and_check_for_question_answering(self, config, input_ids,
+                                                token_type_ids, input_mask):
+        model = RoFormerv2ForQuestionAnswering(RoFormerv2Model(**config))
+        model.eval()
+        result = model(
+            input_ids,
+            attention_mask=input_mask,
+            token_type_ids=token_type_ids,
+        )
+        self.parent.assertEqual(
+            result[0].shape, [self.config.batch_size, self.config.seq_length])
+        self.parent.assertEqual(
+            result[1].shape, [self.config.batch_size, self.config.seq_length])
+
+    def create_and_check_for_token_classification(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+    ):
+        model = RoFormerv2ForTokenClassification(RoFormerv2Model(**config),
+                                                 num_classes=self.num_classes)
+        model.eval()
+        result = model(input_ids,
+                       attention_mask=input_mask,
+                       token_type_ids=token_type_ids)
+        self.parent.assertEqual(result.shape, [
+            self.config.batch_size, self.config.seq_length,
+            self.config.num_classes
+        ])
+
+
+class RoFormerv2ModelTest(ModelTesterMixin, unittest.TestCase):
+    base_model_class = RoFormerv2Model
+
+    all_model_classes = (
+        RoFormerv2ForMaskedLM,
+        RoFormerv2ForSequenceClassification,
+        RoFormerv2ForTokenClassification,
+        RoFormerv2ForQuestionAnswering,
+        RoFormerv2ForMultipleChoice,
+    )
+
+    def setUp(self):
+        self.model_tester = RoFormerv2ModelTester(self)
+
+    def test_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model(*config_and_inputs)
+
+    def test_for_masked_lm(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
+
+    def test_for_multiple_choice(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_multiple_choice(
+            *config_and_inputs)
+
+    def test_for_question_answering(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_question_answering(
+            *config_and_inputs)
+
+    def test_for_sequence_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_sequence_classification(
+            *config_and_inputs)
+
+    def test_for_token_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_token_classification(
+            *config_and_inputs)
+
+    @slow
+    def test_model_from_pretrained(self):
+        for model_name in list(
+                RoFormerv2PretrainedModel.pretrained_init_configuration)[:1]:  # first configured name only
+            model = RoFormerv2Model.from_pretrained(model_name)
+            self.assertIsNotNone(model)
+
+
+class RoFormerv2ModelIntegrationTest(unittest.TestCase):
+
+    @slow
+    def test_inference_no_attention(self):
+        model = RoFormerv2Model.from_pretrained(
+            "roformer_v2_chinese_char_small")
+        model.eval()
+        input_ids = paddle.to_tensor(
+            [[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
+        with paddle.no_grad():
+            output = model(input_ids, output_hidden_states=True)[0]
+        expected_shape = [1, 11, 384]
+        self.assertEqual(output.shape, expected_shape)
+
+        expected_slice = paddle.to_tensor(
+            [[[0.75068903, 0.13977423, 0.07971212],
+              [0.08614583, 0.21606587, -1.08551681],
+              [0.98021960, -0.85751861, -1.42552316]]])
+
+        self.assertTrue(
+            paddle.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))
+
+    @slow
+    def test_inference_with_attention(self):
+        model = RoFormerv2Model.from_pretrained(
+            "roformer_v2_chinese_char_small")
+        model.eval()
+        input_ids = paddle.to_tensor(
+            [[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
+        attention_mask = paddle.to_tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
+        with paddle.no_grad():
+            output = model(input_ids,
+                           attention_mask=attention_mask,
+                           output_hidden_states=True)[0]
+        expected_shape = [1, 11, 384]
+        self.assertEqual(output.shape, expected_shape)
+
+        expected_slice = paddle.to_tensor(  # NOTE(review): same expected values as the no-attention test - confirm this is intended
+            [[[0.75068903, 0.13977423, 0.07971212],
+              [0.08614583, 0.21606587, -1.08551681],
+              [0.98021960, -0.85751861, -1.42552316]]])
+        self.assertTrue(
+            paddle.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/transformers/roformerv2/test_tokenizer.py b/tests/transformers/roformerv2/test_tokenizer.py
new file mode 100644
index 000000000000..e9d129ae33e2
--- /dev/null
+++ b/tests/transformers/roformerv2/test_tokenizer.py
@@ -0,0 +1,261 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+from paddlenlp.data.vocab import Vocab
+
+from paddlenlp.transformers.roformerv2.tokenizer import (BasicTokenizer,
+                                                         RoFormerv2Tokenizer,
+                                                         WordpieceTokenizer)
+
+from tests.testing_utils import slow
+from tests.transformers.test_tokenizer_common import TokenizerTesterMixin, filter_non_english
+
+
+class RoFormerv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+
+    tokenizer_class = RoFormerv2Tokenizer
+    space_between_special_tokens = True
+    from_pretrained_filter = filter_non_english
+    test_seq2seq = True
+
+    def setUp(self):
+        self.from_pretrained_kwargs = {"do_lower_case": False}
+
+        super().setUp()
+        vocab_tokens = [
+            "[UNK]",
+            "[CLS]",
+            "[SEP]",
+            "[PAD]",
+            "[MASK]",
+            "want",
+            "##want",
+            "##ed",
+            "wa",
+            "un",
+            "runn",
+            "##ing",
+            ",",
+            "low",
+            "lowest",
+        ]
+
+        self.vocab_file = os.path.join(
+            self.tmpdirname,
+            RoFormerv2Tokenizer.resource_files_names["vocab_file"])
+        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
+            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))  # one token per line; line index is the token id
+
+        self.vocab = Vocab.from_dict(
+            {token: index
+             for index, token in enumerate(vocab_tokens)},
+            unk_token='[UNK]',
+            pad_token='[PAD]',
+            bos_token='[CLS]',
+            eos_token='[SEP]',
+        )
+
+    def get_input_output_texts(self, tokenizer):
+        input_text = "UNwant\u00E9d,running"
+        output_text = "unwanted, running"
+        return input_text, output_text
+
+    def test_full_tokenizer(self):
+        tokenizer = self.tokenizer_class(self.vocab_file)
+
+        tokens = tokenizer.tokenize("UNwant\u00E9d,running")
+        self.assertListEqual(tokens,
+                             ["un", "##want", "##ed", ",", "runn", "##ing"])
+        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens),
+                             [9, 6, 7, 12, 10, 11])  # positions of the tokens in the setUp vocab list
+
+    def test_chinese(self):
+        tokenizer = BasicTokenizer()
+
+        self.assertListEqual(tokenizer.tokenize("ah\u535A\u63A8zz"),
+                             ["ah", "\u535A", "\u63A8", "zz"])
+
+    def test_basic_tokenizer_lower(self):
+        tokenizer = BasicTokenizer(do_lower_case=True)
+
+        self.assertListEqual(tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "),
+                             ["hello", "!", "how", "are", "you", "?"])
+        self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+
+    def test_basic_tokenizer_lower_strip_accents_false(self):
+        tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=False)
+
+        self.assertListEqual(tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "),
+                             ["hällo", "!", "how", "are", "you", "?"])
+        self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["h\u00E9llo"])
+
+    def test_basic_tokenizer_lower_strip_accents_true(self):
+        tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=True)  # explicit flag; the *_default test covers the implicit case
+
+        self.assertListEqual(tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "),
+                             ["hallo", "!", "how", "are", "you", "?"])
+        self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+
+    def test_basic_tokenizer_lower_strip_accents_default(self):
+        tokenizer = BasicTokenizer(do_lower_case=True)
+
+        self.assertListEqual(tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "),
+                             ["hallo", "!", "how", "are", "you", "?"])
+        self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+
+    def test_basic_tokenizer_no_lower(self):
+        tokenizer = BasicTokenizer(do_lower_case=False)
+
+        self.assertListEqual(tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "),
+                             ["HeLLo", "!", "how", "Are", "yoU", "?"])
+
+    def test_basic_tokenizer_no_lower_strip_accents_false(self):
+        tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=False)
+
+        self.assertListEqual(tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "),
+                             ["HäLLo", "!", "how", "Are", "yoU", "?"])
+
+    def test_basic_tokenizer_no_lower_strip_accents_true(self):
+        tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=True)
+
+        self.assertListEqual(tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "),
+                             ["HaLLo", "!", "how", "Are", "yoU", "?"])
+
+    def test_basic_tokenizer_respects_never_split_tokens(self):
+        tokenizer = BasicTokenizer(do_lower_case=False, never_split=["[UNK]"])
+
+        self.assertListEqual(
+            tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"),
+            ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"])
+
+    def test_wordpiece_tokenizer(self):
+        vocab_tokens = [
+            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
+            "runn", "##ing"
+        ]
+
+        vocab = {}
+        for (i, token) in enumerate(vocab_tokens):
+            vocab[token] = i
+        tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
+
+        self.assertListEqual(tokenizer.tokenize(""), [])
+
+        self.assertListEqual(tokenizer.tokenize("unwanted running"),
+                             ["un", "##want", "##ed", "runn", "##ing"])
+
+        self.assertListEqual(tokenizer.tokenize("unwantedX running"),
+                             ["[UNK]", "runn", "##ing"])
+
+    def test_clean_text(self):
+        tokenizer = self.get_tokenizer()
+
+        # Example taken from the issue https://github.com/huggingface/tokenizers/issues/340
+        self.assertListEqual(
+            [tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]],
+            [["[UNK]"], [], ["[UNK]"]])
+
+    # @slow
+    def test_sequence_builders(self):
+        tokenizer = self.tokenizer_class.from_pretrained(
+            "roformer_v2_chinese_char_small")  # a RoFormerv2 name; "roformer-chinese-small" is not one of this tokenizer's configured names
+
+        text = tokenizer.encode("sequence builders",
+                                return_token_type_ids=None,
+                                add_special_tokens=False)["input_ids"]
+        text_2 = tokenizer.encode("multi-sequence build",
+                                  return_token_type_ids=None,
+                                  add_special_tokens=False)["input_ids"]
+
+        encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
+        encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
+        cls_id, sep_id = tokenizer.cls_token_id, tokenizer.sep_token_id  # read from the loaded vocab rather than hard-coding 101/102
+        self.assertListEqual(encoded_sentence, [cls_id] + text + [sep_id])
+        self.assertListEqual(encoded_pair, [cls_id] + text + [sep_id] + text_2 + [sep_id])
+
+    def test_offsets_with_special_characters(self):
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(
+                    f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                tokenizer = self.tokenizer_class.from_pretrained(
+                    pretrained_name, **kwargs)
+
+                # sentence = f"testing with {tokenizer.mask_token} simple sentence"
+                sentence = f"a simple {tokenizer.mask_token} allennlp sentence."
+                tokens = tokenizer.encode(
+                    sentence,
+                    return_attention_mask=False,
+                    return_token_type_ids=False,
+                    return_offsets_mapping=True,
+                    add_special_tokens=True,
+                )
+                expected_results = [
+                    ((0, 0), tokenizer.cls_token),
+                    ((0, 1), "a"),
+                    ((2, 8), "simple"),
+                    ((9, 15), tokenizer.mask_token),
+                    ((16, 21), "allen"),
+                    ((21, 23), "##nl"),
+                    ((23, 24), "##p"),
+                    ((25, 33), "sentence"),
+                    ((33, 34), "."),
+                    ((0, 0), tokenizer.sep_token),
+                ]
+
+                self.assertEqual([e[1] for e in expected_results],
+                                 tokenizer.convert_ids_to_tokens(
+                                     tokens["input_ids"]))
+                self.assertEqual([e[0] for e in expected_results],
+                                 tokens["offset_mapping"])
+
+    def test_change_tokenize_chinese_chars(self):
+        list_of_commun_chinese_char = ["的", "人", "有"]
+        text_with_chinese_char = "".join(list_of_commun_chinese_char)
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(
+                    f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+
+                kwargs["tokenize_chinese_chars"] = True
+                tokenizer = self.tokenizer_class.from_pretrained(
+                    pretrained_name, **kwargs)
+
+                ids_without_spe_char_p = tokenizer.encode(
+                    text_with_chinese_char,
+                    return_token_type_ids=None,
+                    add_special_tokens=False)["input_ids"]
+
+                tokens_without_spe_char_p = tokenizer.convert_ids_to_tokens(
+                    ids_without_spe_char_p)
+
+                # it is expected that each Chinese character is not preceded by "##"
+                self.assertListEqual(tokens_without_spe_char_p,
+                                     list_of_commun_chinese_char)
+                '''
+                kwargs["tokenize_chinese_chars"] = False
+                tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+                ids_without_spe_char_p = tokenizer.encode(text_with_chinese_char, return_token_type_ids=None,add_special_tokens=False)["input_ids"]
+
+                tokens_without_spe_char_p = tokenizer.convert_ids_to_tokens(ids_without_spe_char_p)
+
+                # it is expected that only the first Chinese character is not preceded by "##".
+                expected_tokens = [
+                    f"##{token}" if idx != 0 else token for idx, token in enumerate(list_of_commun_chinese_char)
+                ]
+                self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
+                '''