From cdc8b74c533f71b8b25df62dde67ddfff611f8e2 Mon Sep 17 00:00:00 2001 From: Zhong Hui Date: Thu, 25 Aug 2022 14:12:38 +0800 Subject: [PATCH 1/3] update return_dict/label in skep model --- model_zoo/ernie-1.0/run_pretrain.py | 56 ++++-- paddlenlp/transformers/skep/modeling.py | 216 ++++++++++++++++++++---- 2 files changed, 225 insertions(+), 47 deletions(-) diff --git a/model_zoo/ernie-1.0/run_pretrain.py b/model_zoo/ernie-1.0/run_pretrain.py index 2e47f8bd89b1..d6bb1cfccc38 100644 --- a/model_zoo/ernie-1.0/run_pretrain.py +++ b/model_zoo/ernie-1.0/run_pretrain.py @@ -15,6 +15,7 @@ ERNIE-1.0 pretraining scripts. """ import argparse +import contextlib import os import sys import random @@ -66,6 +67,7 @@ def create_pretrained_dataset( args.eval_iters * data_world_size, args.micro_batch_size * args.test_iters * data_world_size ] + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( data_prefix=data_file, args=args, @@ -291,6 +293,7 @@ def args_post_process(args, worker_num): "cannot do gradient accumulate, global_batch_size: {} micro_batch_size: {}".format( args.global_batch_size, micro_batch_size) accumulate_steps = bsz_per_dp // micro_batch_size + assert accumulate_steps >= 1, f"Larger global_batch_size: {arg.global_batch_size} is expect, micro_batch_size is {micro_batch_size}, but only {bsz_per_dp} on each card!" args.eval_iters *= accumulate_steps args.test_iters *= accumulate_steps @@ -451,6 +454,7 @@ def do_train(args): optimizer = fleet.distributed_optimizer(optimizer) tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name_or_path) + # Must extend chinese char for ErnieTokenizer tokenizer.extend_chinese_char() data_file = get_train_data_file(args) @@ -461,6 +465,7 @@ def do_train(args): data_world_size=worker_num, data_world_rank=worker_index, max_seq_len=args.max_seq_len, + binary_head=args.binary_head, current_step=global_step) # load checkpoint vars @@ -516,6 +521,7 @@ def do_train(args): # time count train_reader_cost = 0.0 train_run_cost = 0.0 + tr_loss = paddle.to_tensor(0.0) reader_start = time.time() for step, batch in enumerate(train_data_loader()): @@ -532,7 +538,17 @@ def do_train(args): input_ids, segment_ids, input_mask, masked_lm_positions, \ masked_lm_labels, next_sentence_labels = batch - with model.no_sync(): + ctx_manager = contextlib.nullcontext() if sys.version_info >= ( + 3, 7) else contextlib.suppress() + + if worker_num > 1 and (args.use_recompute + or args.accumulate_steps > 1): + ctx_manager = model.no_sync() + else: + ctx_manager = contextlib.nullcontext() if sys.version_info >= ( + 3, 7) else contextlib.suppress() + + with ctx_manager: with paddle.amp.auto_cast(args.use_amp, custom_white_list=[ 'softmax', @@ -568,37 +584,47 @@ def do_train(args): loss = criterion(prediction_scores, None, masked_lm_labels) + if args.accumulate_steps >= 1: + tr_loss_step = loss / args.accumulate_steps + else: + tr_loss_step = loss + if args.use_amp: - scaler.scale(loss).backward() + scaler.scale(tr_loss_step).backward() else: - loss.backward() + tr_loss_step.backward() - fused_allreduce_gradients(list(model.parameters()), None) + tr_loss += tr_loss_step + + loss_global["loss"] += loss.detach() + if args.binary_head: + loss_global["lm_loss"] += lm_loss.detach() + loss_global["sop_loss"] += sop_loss.detach() + + # Skip for accumulate_steps in global step + if (step + 1) % args.accumulate_steps != 0: + continue + + if worker_num > 1 and args.use_recompute: + fused_allreduce_gradients(list(model.parameters()), None) if args.use_amp: - scaler.minimize(optimizer, 
loss) + scaler.minimize(optimizer, tr_loss) else: optimizer.step() optimizer.clear_grad() train_run_cost += time.time() - train_start - - # Skip for accumulate_steps in global step - if (step + 1) % args.accumulate_steps != 0: - continue + tr_loss.subtract_(tr_loss) global_step += 1 - loss_global["loss"] += loss.detach() - if args.binary_head: - loss_global["lm_loss"] += lm_loss.detach() - loss_global["sop_loss"] += sop_loss.detach() - if global_step % args.logging_freq == 0: log_info_dict = dict() log_info_dict["global_step"] = global_step for k, v in loss_global.items(): - log_info_dict[k] = all_gather(v) / args.logging_freq + log_info_dict[k] = all_gather( + v) / args.logging_freq / args.accumulate_steps v.subtract_(v) if worker_index == 0: speed = args.logging_freq / (time.time() - tic_train) diff --git a/paddlenlp/transformers/skep/modeling.py b/paddlenlp/transformers/skep/modeling.py index a65da0af5acc..10d0f67f618a 100644 --- a/paddlenlp/transformers/skep/modeling.py +++ b/paddlenlp/transformers/skep/modeling.py @@ -284,7 +284,10 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None, - attention_mask=None): + attention_mask=None, + output_hidden_states=False, + output_attentions=False, + return_dict=False): r""" The SkepModel forward method, overrides the `__call__()` special method. @@ -319,9 +322,23 @@ def forward(self, For example, its shape can be [batch_size, sequence_length], [batch_size, sequence_length, sequence_length], [batch_size, num_attention_heads, sequence_length, sequence_length]. Defaults to `None`, which means nothing needed to be prevented attention to. + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.ModelOutput` object. If `False`, the output + will be a tuple of tensors. Defaults to `False`. Returns: - tuple: Returns tuple (`sequence_output`, `pooled_output`). + An instance of :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions` if + `return_dict=True`. Otherwise it returns a tuple of tensors corresponding + to ordered and not None (depending on the input arguments) fields of + :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions`. + + if the reuslt is tuple: Returns tuple (`sequence_output`, `pooled_output`). 
With the fields: @@ -356,10 +373,28 @@ def forward(self, embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids) - encoder_outputs = self.encoder(embedding_output, attention_mask) - sequence_output = encoder_outputs - pooled_output = self.pooler(sequence_output) - return sequence_output, pooled_output + encoder_outputs = self.encoder( + embedding_output, + attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + if isinstance(encoder_outputs, type(embedding_output)): + sequence_output = encoder_outputs + pooled_output = self.pooler(sequence_output) + return (sequence_output, pooled_output) + else: + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions) def get_input_embeddings(self) -> nn.Embedding: """get skep input word embedding @@ -409,7 +444,11 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None, - attention_mask=None): + attention_mask=None, + labels=None, + output_hidden_states=False, + output_attentions=False, + return_dict=False): r""" The SkepForSequenceClassification forward method, overrides the __call__() special method. @@ -422,10 +461,25 @@ def forward(self, See :class:`SkepModel`. attention_mask (Tensor, optional): See :class:`SkepModel`. + labels (Tensor of shape `(batch_size,)`, optional): + Labels for computing the sequence classification/regression loss. + Indices should be in `[0, ..., num_classes - 1]`. If `num_classes == 1` + a regression loss is computed (Mean-Square loss), If `num_classes > 1` + a classification loss is computed (Cross-Entropy). + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - Tensor: Returns tensor `logits`, a tensor of the input text classification logits. - Shape as `[batch_size, num_classes]` and dtype as float32. + An instance of :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput`. Example: .. 
code-block:: @@ -441,14 +495,43 @@ def forward(self, logits = model(**inputs) """ - _, pooled_output = self.skep(input_ids, - token_type_ids=token_type_ids, - position_ids=position_ids, - attention_mask=attention_mask) + outputs = self.skep(input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + pooled_output = outputs[1] pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) - return logits + + loss = None + if labels is not None: + if self.num_classes == 1: + loss_fct = paddle.nn.MSELoss() + loss = loss_fct(logits, labels) + elif labels.dtype == paddle.int64 or labels.dtype == paddle.int32: + loss_fct = paddle.nn.CrossEntropyLoss() + loss = loss_fct(logits.reshape((-1, self.num_classes)), + labels.reshape((-1, ))) + else: + loss_fct = paddle.nn.BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits, ) + outputs[2:] + return ((loss, ) + output) if loss is not None else ( + output[0] if len(output) == 1 else output) + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) class SkepForTokenClassification(SkepPretrainedModel): @@ -482,7 +565,11 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None, - attention_mask=None): + attention_mask=None, + labels=None, + output_hidden_states=False, + output_attentions=False, + return_dict=False): r""" The SkepForTokenClassification forward method, overrides the __call__() special method. @@ -495,10 +582,22 @@ def forward(self, See :class:`SkepModel`. attention_mask (Tensor, optional): See :class:`SkepModel`. + labels (Tensor of shape `(batch_size, sequence_length)`, optional): + Labels for computing the token classification loss. Indices should be in `[0, ..., num_classes - 1]`. + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - Tensor: Returns tensor `logits`, a tensor of the input token classification logits. - Shape as `[batch_size, sequence_length, num_classes]` and dtype as `float32`. + An instance of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput`. Example: .. 
code-block:: @@ -514,14 +613,36 @@ def forward(self, logits = model(**inputs) """ - sequence_output, _ = self.skep(input_ids, - token_type_ids=token_type_ids, - position_ids=position_ids, - attention_mask=attention_mask) + outputs = self.skep(input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) logits = self.classifier(sequence_output) - return logits + + loss = None + if labels is not None: + loss_fct = paddle.nn.CrossEntropyLoss() + loss = loss_fct(logits.reshape((-1, self.num_classes)), + labels.reshape((-1, ))) + + if not return_dict: + output = (logits, ) + outputs[2:] + return ((loss, ) + output) if loss is not None else ( + output[0] if len(output) == 1 else output) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) class SkepCrfForTokenClassification(SkepPretrainedModel): @@ -564,7 +685,10 @@ def forward(self, position_ids=None, attention_mask=None, seq_lens=None, - labels=None): + labels=None, + output_hidden_states=False, + output_attentions=False, + return_dict=False): r""" The SkepCrfForTokenClassification forward method, overrides the __call__() special method. @@ -584,9 +708,22 @@ def forward(self, labels (Tensor, optional): The input label tensor. Its data type should be int64 and its shape is `[batch_size, sequence_length]`. + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - Tensor: Returns tensor `loss` if `labels` is not None. Otherwise, returns tensor `prediction`. + An instance of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput`. + + if return_dict is False, Returns tensor `loss` if `labels` is not None. Otherwise, returns tensor `prediction`. - `loss` (Tensor): The crf loss. Its data type is float32 and its shape is `[batch_size]`. @@ -596,13 +733,15 @@ def forward(self, Its data type is int64 and its shape is `[batch_size, sequence_length]`. 
""" - sequence_output, _ = self.skep(input_ids, - token_type_ids=token_type_ids, - position_ids=position_ids, - attention_mask=attention_mask) - - bigru_output, _ = self.gru( - sequence_output) #, sequence_length=seq_lens) + outputs = self.skep(input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + bigru_output, _ = self.gru(outputs[0]) #, sequence_length=seq_lens) emission = self.fc(bigru_output) if seq_lens is None: @@ -616,9 +755,22 @@ def forward(self, seq_lens = paddle.ones(shape=[input_ids_shape[0]], dtype=paddle.int64) * input_ids_shape[1] + loss, prediction = None, None if labels is not None: loss = self.crf_loss(emission, seq_lens, labels) - return loss else: _, prediction = self.viterbi_decoder(emission, seq_lens) + + # FIXME(wj-Mcat): the output of this old version model is single tensor when return_dict is False + if not return_dict: + # when loss is None, return prediction + if labels is not None: + return loss return prediction + + return TokenClassifierOutput( + loss=loss, + logits=prediction, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) From 0500a18e8e1ba0989627493ff08e1b49b66faac3 Mon Sep 17 00:00:00 2001 From: wj-Mcat <1435130236@qq.com> Date: Thu, 25 Aug 2022 13:13:50 +0000 Subject: [PATCH 2/3] complete skep add-more-output --- paddlenlp/transformers/skep/modeling.py | 23 +++- tests/transformers/skep/test_modeling.py | 144 +++++++++++++++-------- 2 files changed, 113 insertions(+), 54 deletions(-) diff --git a/paddlenlp/transformers/skep/modeling.py b/paddlenlp/transformers/skep/modeling.py index 10d0f67f618a..d18b3bcdcccd 100644 --- a/paddlenlp/transformers/skep/modeling.py +++ b/paddlenlp/transformers/skep/modeling.py @@ -25,6 +25,15 @@ else: from paddlenlp.layers.crf import ViterbiDecoder +from ..model_outputs import ( + BaseModelOutputWithPoolingAndCrossAttentions, + SequenceClassifierOutput, + TokenClassifierOutput, + QuestionAnsweringModelOutput, + MultipleChoiceModelOutput, + MaskedLMOutput, + CausalLMOutputWithCrossAttentions, +) from .. 
import PretrainedModel, register_base_model __all__ = [ @@ -523,8 +532,11 @@ def forward(self, if not return_dict: output = (logits, ) + outputs[2:] - return ((loss, ) + output) if loss is not None else ( - output[0] if len(output) == 1 else output) + if loss is not None: + return (loss, ) + output + if len(output) == 1: + return output[0] + return output return SequenceClassifierOutput( loss=loss, @@ -634,8 +646,11 @@ def forward(self, if not return_dict: output = (logits, ) + outputs[2:] - return ((loss, ) + output) if loss is not None else ( - output[0] if len(output) == 1 else output) + if loss is not None: + return (loss, ) + output + if len(output) == 1: + return output[0] + return output return TokenClassifierOutput( loss=loss, diff --git a/tests/transformers/skep/test_modeling.py b/tests/transformers/skep/test_modeling.py index 03e2ed87cefe..4591a5282859 100644 --- a/tests/transformers/skep/test_modeling.py +++ b/tests/transformers/skep/test_modeling.py @@ -17,6 +17,7 @@ from typing import Optional, Tuple, Dict, Any import paddle from paddle import Tensor +from parameterized import parameterized_class from dataclasses import dataclass, asdict, fields, Field from paddlenlp.transformers import ( @@ -70,6 +71,8 @@ class SkepTestConfig(SkepTestModelConfig): # used for sequence classification num_classes: int = 3 + num_choices: int = 3 + type_sequence_label_size: int = 3 class SkepModelTester: @@ -82,6 +85,11 @@ def __init__(self, parent, config: Optional[SkepTestConfig] = None): self.is_training = self.config.is_training + def __getattr__(self, key: str): + if not hasattr(self.config, key): + raise AttributeError(f'attribute <{key}> not exist') + return getattr(self.config, key) + def prepare_config_and_inputs( self) -> Tuple[Dict[str, Any], Tensor, Tensor, Tensor]: config = self.config @@ -98,23 +106,36 @@ def prepare_config_and_inputs( token_type_ids = ids_tensor([config.batch_size, config.seq_length], config.type_vocab_size) - return config.model_kwargs, input_ids, token_type_ids, input_mask + sequence_labels = None + token_labels = None + choice_labels = None + + if self.parent.use_labels: + sequence_labels = ids_tensor([self.batch_size], + self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], + self.num_classes) + choice_labels = ids_tensor([self.batch_size], self.num_choices) - def create_and_check_model( - self, - config, - input_ids: Tensor, - token_type_ids: Tensor, - input_mask: Tensor, - ): + config = self.get_config() + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def create_and_check_model(self, config, input_ids: Tensor, + token_type_ids: Tensor, input_mask: Tensor, + sequence_labels: Tensor, token_labels: Tensor, + choice_labels: Tensor): model = SkepModel(**config) model.eval() result = model(input_ids, attention_mask=input_mask, - token_type_ids=token_type_ids) - result = model(input_ids, token_type_ids=token_type_ids) - result = model(input_ids) + token_type_ids=token_type_ids, + return_dict=self.parent.return_dict) + result = model(input_ids, + token_type_ids=token_type_ids, + return_dict=self.parent.return_dict) + result = model(input_ids, return_dict=self.parent.return_dict) + self.parent.assertEqual(result[0].shape, [ self.config.batch_size, self.config.seq_length, self.config.hidden_size @@ -123,60 +144,83 @@ def create_and_check_model( result[1].shape, [self.config.batch_size, self.config.hidden_size]) def create_and_check_for_sequence_classification( - self, - 
config, - input_ids: Tensor, - token_type_ids: Tensor, - input_mask: Tensor, - ): + self, config, input_ids: Tensor, token_type_ids: Tensor, + input_mask: Tensor, sequence_labels: Tensor, token_labels: Tensor, + choice_labels: Tensor): model = SkepForSequenceClassification( SkepModel(**config), num_classes=self.config.num_classes) model.eval() - result = model( - input_ids, - attention_mask=input_mask, - token_type_ids=token_type_ids, - ) + result = model(input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + return_dict=self.parent.return_dict, + labels=sequence_labels) + + if token_labels is not None: + result = result[1:] + elif paddle.is_tensor(result): + result = [result] + self.parent.assertEqual( - result.shape, [self.config.batch_size, self.config.num_classes]) + result[0].shape, [self.config.batch_size, self.config.num_classes]) def create_and_check_for_token_classification( - self, - config, - input_ids, - token_type_ids, - input_mask, - ): + self, config, input_ids: Tensor, token_type_ids: Tensor, + input_mask: Tensor, sequence_labels: Tensor, token_labels: Tensor, + choice_labels: Tensor): model = SkepForTokenClassification(SkepModel(**config), num_classes=self.config.num_classes) model.eval() result = model(input_ids, attention_mask=input_mask, - token_type_ids=token_type_ids) - self.parent.assertEqual(result.shape, [ + token_type_ids=token_type_ids, + return_dict=self.parent.return_dict, + labels=token_labels) + + if token_labels is not None: + result = result[1:] + elif paddle.is_tensor(result): + result = [result] + + self.parent.assertEqual(result[0].shape, [ self.config.batch_size, self.config.seq_length, self.config.num_classes ]) def create_and_check_for_crf_token_classification( - self, - config, - input_ids, - token_type_ids, - input_mask, - ): + self, config, input_ids: Tensor, token_type_ids: Tensor, + input_mask: Tensor, sequence_labels: Tensor, token_labels: Tensor, + choice_labels: Tensor): model = SkepCrfForTokenClassification( SkepModel(**config), num_classes=self.config.num_classes) model.eval() result = model(input_ids, attention_mask=input_mask, - token_type_ids=token_type_ids) - self.parent.assertEqual( - result.shape, [self.config.batch_size, self.config.seq_length]) + token_type_ids=token_type_ids, + return_dict=self.parent.return_dict, + labels=token_labels) + # TODO(wj-Mcat): the output of SkepCrfForTokenClassification is wrong + if paddle.is_tensor(result): + result = [result] + + if token_labels is not None: + self.parent.assertEqual(result[0].shape, [self.config.batch_size]) + else: + self.parent.assertEqual( + result[0].shape, + [self.config.batch_size, self.config.seq_length]) def prepare_config_and_inputs_for_common(self): - config, input_ids, token_type_ids, input_mask = self.prepare_config_and_inputs( - ) + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs inputs_dict = { "input_ids": input_ids, "token_type_ids": token_type_ids, @@ -193,8 +237,16 @@ def get_config(self) -> dict: return self.config.model_kwargs +@parameterized_class(("return_dict", "use_labels"), [ + [False, False], + [False, True], + [True, False], + [True, True], +]) class SkepModelTest(ModelTesterMixin, unittest.TestCase): base_model_class = SkepModel + return_dict = False + use_labels = False all_model_classes = ( SkepModel, @@ -207,9 +259,6 @@ class SkepModelTest(ModelTesterMixin, unittest.TestCase): def setUp(self): 
self.model_tester = SkepModelTester(self) - def get_config(): - pass - def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) @@ -225,11 +274,6 @@ def test_for_token_classification(self): *config_and_inputs) def test_for_crf_token_classification(self): - # TODO(wj-Mcat): to activate this method later - # self.skipTest( - # "skip for crf token classification: there are contains something wrong in paddle.text.viterib_decode" - # ) - # return config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_crf_token_classification( *config_and_inputs) From fc7976162fc08d02234d74fbedbb69c38c5f19e5 Mon Sep 17 00:00:00 2001 From: wj-Mcat <1435130236@qq.com> Date: Tue, 30 Aug 2022 11:16:37 +0000 Subject: [PATCH 3/3] refactor simple code --- paddlenlp/transformers/skep/modeling.py | 28 +++++++++++------------- tests/transformers/skep/test_modeling.py | 1 - 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/paddlenlp/transformers/skep/modeling.py b/paddlenlp/transformers/skep/modeling.py index d18b3bcdcccd..9b1ddd71e5e2 100644 --- a/paddlenlp/transformers/skep/modeling.py +++ b/paddlenlp/transformers/skep/modeling.py @@ -389,21 +389,19 @@ def forward(self, output_hidden_states=output_hidden_states, return_dict=return_dict) - if isinstance(encoder_outputs, type(embedding_output)): - sequence_output = encoder_outputs - pooled_output = self.pooler(sequence_output) - return (sequence_output, pooled_output) - else: - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(sequence_output) - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions) + if paddle.is_tensor(encoder_outputs): + encoder_outputs = (encoder_outputs, ) + + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions) def get_input_embeddings(self) -> nn.Embedding: """get skep input word embedding diff --git a/tests/transformers/skep/test_modeling.py b/tests/transformers/skep/test_modeling.py index 4591a5282859..b3016eaf2c58 100644 --- a/tests/transformers/skep/test_modeling.py +++ b/tests/transformers/skep/test_modeling.py @@ -250,7 +250,6 @@ class SkepModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = ( SkepModel, - # TODO(wj-Mcat): to activate this model later SkepCrfForTokenClassification, SkepForSequenceClassification, SkepForTokenClassification,
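
The snippet below is a minimal usage sketch (not part of the patch series) of the API surface this series adds: the `labels`, `output_hidden_states`, `output_attentions`, and `return_dict` arguments on the SKEP task heads, and the backward-compatible tuple output. It assumes PaddleNLP with these three patches applied; the checkpoint name "skep_ernie_1.0_large_ch", the sample sentence, and num_classes=2 are illustrative assumptions, not values taken from the diff.

# Sketch only: illustrates the new labels/return_dict behaviour of
# SkepForSequenceClassification as introduced by this patch series.
# The checkpoint name and num_classes below are assumptions for the example.
import paddle
from paddlenlp.transformers import SkepForSequenceClassification, SkepTokenizer

model = SkepForSequenceClassification.from_pretrained(
    "skep_ernie_1.0_large_ch", num_classes=2)
tokenizer = SkepTokenizer.from_pretrained("skep_ernie_1.0_large_ch")
model.eval()

# Tokenize one sentence and add the batch dimension.
inputs = tokenizer("这家酒店的服务很不错")
inputs = {k: paddle.to_tensor([v]) for k, v in inputs.items()}

# Old behaviour is preserved: with no labels and return_dict=False (the
# default), the forward pass still returns the bare logits tensor.
logits = model(**inputs)
print(logits.shape)  # [1, 2]

# New behaviour: pass integer labels to get a cross-entropy loss, and
# return_dict=True to get a SequenceClassifierOutput with named fields.
labels = paddle.to_tensor([1])
outputs = model(**inputs,
                labels=labels,
                return_dict=True,
                output_hidden_states=True)
print(outputs.loss)                # scalar loss (CrossEntropyLoss for int labels)
print(outputs.logits.shape)        # [1, 2]
print(len(outputs.hidden_states))  # embedding output plus one tensor per layer

With return_dict=False and labels supplied, the same call instead returns a tuple whose first element is the loss and whose second element is the logits, mirroring the `(loss, ) + output` path added in the patch.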