From 222fc3154ab52d8b28c0b37edd5ce31db29b9e88 Mon Sep 17 00:00:00 2001
From: gongenlei
Date: Fri, 17 Jun 2022 18:50:07 +0800
Subject: [PATCH] [FasterBART] Support dygraph to static and inference (#2519)

* support d2s and inference

* default to beam_search

* add comments

* update tokenizer usage

* rm disable_faster_encoder
---
 .../sample/bart_decoding_sample.py            |  18 +--
 .../sample/bart_export_model_sample.py        | 140 ++++++++++++++++++
 .../sample/bart_inference.py                  | 112 ++++++++++++++
 .../transformer/decoding.py                   |   3 +-
 .../transformer/faster_transformer.py         |  13 +-
 paddlenlp/transformers/bart/modeling.py       |  14 +-
 6 files changed, 283 insertions(+), 17 deletions(-)
 create mode 100644 paddlenlp/ops/faster_transformer/sample/bart_export_model_sample.py
 create mode 100644 paddlenlp/ops/faster_transformer/sample/bart_inference.py

diff --git a/paddlenlp/ops/faster_transformer/sample/bart_decoding_sample.py b/paddlenlp/ops/faster_transformer/sample/bart_decoding_sample.py
index 594f050b767d..05a9b7fd4873 100644
--- a/paddlenlp/ops/faster_transformer/sample/bart_decoding_sample.py
+++ b/paddlenlp/ops/faster_transformer/sample/bart_decoding_sample.py
@@ -38,11 +38,9 @@ def postprocess_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
     return seq
 
 
-def prepare_input(tokenizer, sentences, pad_id):
-    word_pad = Pad(pad_id, dtype="int64")
-    tokenized = tokenizer(sentences, return_length=True)
-    inputs = word_pad([i["input_ids"] for i in tokenized])
-    input_ids = paddle.to_tensor(inputs)
+def prepare_input(tokenizer, sentences):
+    tokenized = tokenizer(sentences, padding=True)
+    input_ids = paddle.to_tensor(tokenized['input_ids'], dtype='int64')
     return input_ids
 
 
@@ -57,13 +55,13 @@ def parse_args():
     )
     parser.add_argument(
         "--decoding_strategy",
-        default='sampling',
+        default='beam_search',
         type=str,
         help=
        "The decoding strategy. Can be one of [greedy_search, beam_search, sampling]"
     )
     parser.add_argument("--beam_size",
-                        default=4,
+                        default=5,
                         type=int,
                         help="The parameters for beam search. ")
     parser.add_argument(
@@ -77,7 +75,7 @@ def parse_args():
         type=float,
         help="The probability threshold to procedure topp sampling. ")
     parser.add_argument("--max_length",
-                        default=50,
+                        default=20,
                         type=int,
                         help="Maximum output length. ")
     parser.add_argument("--diversity_rate",
@@ -103,6 +101,7 @@ def do_predict(args):
     logger.info('Loading the model parameters, please wait...')
     model = BartForConditionalGeneration.from_pretrained(
         args.model_name_or_path)
+    # Set evaluation mode
     model.eval()
 
     sentences = [
@@ -115,7 +114,7 @@ def do_predict(args):
     bos_id = model.bart.config['bos_token_id']
     eos_id = model.bart.config['eos_token_id']
     pad_id = model.bart.config['pad_token_id']
-    input_ids = prepare_input(tokenizer, sentences, pad_id)
+    input_ids = prepare_input(tokenizer, sentences)
 
     # Define model
     faster_bart = model
@@ -155,4 +154,5 @@ def do_predict(args):
 if __name__ == "__main__":
     args = parse_args()
     pprint(args)
+
     do_predict(args)
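The `prepare_input` rewrite above drops the manual `Pad`-based batching in favor of the
tokenizer's built-in padding. A minimal sketch of the new behavior (the sentences are
illustrative):

    from paddlenlp.transformers import BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("bart-base")
    # padding=True pads every sequence in the batch to the longest one,
    # so the separate Pad(pad_id) step is no longer needed.
    tokenized = tokenizer(["Hello world.", "Hi."], padding=True)
    assert len(tokenized["input_ids"][0]) == len(tokenized["input_ids"][1])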
diff --git a/paddlenlp/ops/faster_transformer/sample/bart_export_model_sample.py b/paddlenlp/ops/faster_transformer/sample/bart_export_model_sample.py
new file mode 100644
index 000000000000..82aeaa3aeec9
--- /dev/null
+++ b/paddlenlp/ops/faster_transformer/sample/bart_export_model_sample.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import paddle
+from pprint import pprint
+from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer
+from paddlenlp.ops import FasterBART
+from paddlenlp.utils.log import logger
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_name_or_path",
+                        default="bart-base",
+                        type=str,
+                        help="The model name or path of the BART model to use. ")
+    parser.add_argument("--inference_model_dir",
+                        default="./infer_model/",
+                        type=str,
+                        help="Path to save the inference model of BART. ")
+    parser.add_argument(
+        "--topk",
+        default=4,
+        type=int,
+        help="The number of candidates for top-k sampling. ")
+    parser.add_argument(
+        "--topp",
+        default=1.0,
+        type=float,
+        help="The probability threshold for top-p sampling. ")
+    parser.add_argument("--max_out_len",
+                        default=20,
+                        type=int,
+                        help="Maximum output length. ")
+    parser.add_argument("--temperature",
+                        default=1.0,
+                        type=float,
+                        help="The temperature for sampling. ")
+    parser.add_argument("--num_return_sequences",
+                        default=1,
+                        type=int,
+                        help="The number of returned sequences. ")
+    parser.add_argument("--use_fp16_decoding",
+                        action="store_true",
+                        help="Whether to use fp16 decoding to predict. ")
+    parser.add_argument("--decoding_strategy",
+                        default="beam_search",
+                        choices=["sampling", "beam_search"],
+                        type=str,
+                        help="The main strategy to decode. ")
+    parser.add_argument(
+        "--num_beams",
+        default=5,
+        type=int,
+        help="The number of beams for beam search. ")
+    parser.add_argument("--diversity_rate",
+                        default=0.0,
+                        type=float,
+                        help="The diversity rate for beam search. ")
+    parser.add_argument("--repetition_penalty",
+                        default=1.0,
+                        type=float,
+                        help="The repetition penalty to apply. ")
+    parser.add_argument("--length_penalty",
+                        default=0.0,
+                        type=float,
+                        help="The length penalty for decoding. ")
+    parser.add_argument("--early_stopping",
+                        action="store_true",
+                        help="Whether to do early stopping. ")
+
+    args = parser.parse_args()
+    return args
+
+
+def do_predict(args):
+    place = "gpu"
+    place = paddle.set_device(place)
+
+    model = BartForConditionalGeneration.from_pretrained(
+        args.model_name_or_path)
+    tokenizer = BartTokenizer.from_pretrained(args.model_name_or_path)
+
+    # Set evaluation mode before FasterBART enables the faster encoder.
+    model.eval()
+
+    faster_bart = FasterBART(model=model,
+                             use_fp16_decoding=args.use_fp16_decoding)
+    # Set evaluation mode
+    faster_bart.eval()
+
+    # Convert dygraph model to static graph model
+    faster_bart = paddle.jit.to_static(
+        faster_bart,
+        input_spec=[
+            # input_ids
+            paddle.static.InputSpec(shape=[None, None], dtype="int32"),
+            # encoder_output
+            None,
+            # seq_len
+            None,
+            args.num_beams,  # num_beams
+            args.topk,  # top_k
+            args.topp,  # top_p
+            args.decoding_strategy,  # decode_strategy
+            tokenizer.bos_token_id,  # bos_token_id
+            tokenizer.eos_token_id,  # eos_token_id
+            tokenizer.pad_token_id,  # pad_token_id
+            tokenizer.eos_token_id,  # decoder_start_token_id
+            args.max_out_len,  # max_length
+            args.diversity_rate,  # diversity_rate
+            args.length_penalty,  # length_penalty
+            args.num_return_sequences,  # num_return_sequences
+            args.early_stopping,  # early_stopping
+            tokenizer.eos_token_id,  # forced_eos_token_id
+        ])
+
+    # Save converted static graph model
+    paddle.jit.save(faster_bart, os.path.join(args.inference_model_dir, "bart"))
+    logger.info("BART has been saved to {}.".format(args.inference_model_dir))
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    pprint(args)
+
+    do_predict(args)
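Once the export script has been run, the saved program can also be loaded back directly
in Python rather than through the Paddle Inference API shown below. A minimal sketch,
assuming the default --inference_model_dir and that all generation options were baked
into the graph as constants by the input_spec above:

    import paddle
    from paddlenlp.transformers import BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("bart-base")
    bart = paddle.jit.load("./infer_model/bart")
    bart.eval()

    features = tokenizer(["Drop everything now."], padding=True)
    input_ids = paddle.to_tensor(features["input_ids"], dtype="int32")
    # Only input_ids is fed at run time; beam size, max length, etc. are now constants.
    output_ids = bart(input_ids)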
diff --git a/paddlenlp/ops/faster_transformer/sample/bart_inference.py b/paddlenlp/ops/faster_transformer/sample/bart_inference.py
new file mode 100644
index 000000000000..9461397531e2
--- /dev/null
+++ b/paddlenlp/ops/faster_transformer/sample/bart_inference.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import numpy as np
+from pprint import pprint
+
+import paddle
+import paddle.inference as paddle_infer
+
+from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer
+from paddlenlp.ops.ext_utils import load
+
+
+def setup_args():
+    """Set up arguments."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--inference_model_dir",
+                        default="./infer_model/",
+                        type=str,
+                        help="Path to load the inference model of BART from. ")
+
+    args = parser.parse_args()
+
+    return args
+
+
+def prepare_input(tokenizer, sentences):
+    tokenized = tokenizer(sentences, padding=True)
+    input_ids = np.asarray(tokenized['input_ids'], dtype="int32")
+    return input_ids
+
+
+def postprocess_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
+    """
+    Post-process the decoded sequence.
+    """
+    eos_pos = len(seq) - 1
+    for i, idx in enumerate(seq):
+        if idx == eos_idx:
+            eos_pos = i
+            break
+    seq = [
+        idx for idx in seq[:eos_pos + 1]
+        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
+    ]
+    return seq
+
+
+def infer(args):
+    model_name = 'bart-base'
+    tokenizer = BartTokenizer.from_pretrained(model_name)
+
+    sentences = [
+        "I love that girl, but does not me.",
+        "She is so that I can not help glance at .",
+        "Nothing's gonna my love for you.",
+        "Drop everything now. Meet me in the pouring . Kiss me on the sidewalk.",
+    ]
+
+    input_ids = prepare_input(tokenizer, sentences)
+
+    # Load the FasterTransformer custom op library.
+    load("FasterTransformer", verbose=True)
+
+    config = paddle_infer.Config(
+        os.path.join(args.inference_model_dir, "bart.pdmodel"),
+        os.path.join(args.inference_model_dir, "bart.pdiparams"))
+
+    config.enable_use_gpu(100, 0)
+    config.disable_glog_info()
+    # The `embedding_eltwise_layernorm_fuse_pass` fails on this model, so skip it.
+    config.delete_pass("embedding_eltwise_layernorm_fuse_pass")
+    predictor = paddle_infer.create_predictor(config)
+
+    input_names = predictor.get_input_names()
+    input_handle = predictor.get_input_handle(input_names[0])
+    input_handle.copy_from_cpu(input_ids.astype("int32"))
+
+    predictor.run()
+
+    output_names = predictor.get_output_names()
+    output_handle = predictor.get_output_handle(output_names[0])
+    output_data = output_handle.copy_to_cpu()
+
+    for idx, sample in enumerate(output_data.transpose([1, 2, 0]).tolist()):
+        for beam_idx, beam in enumerate(sample):
+            if beam_idx >= len(sample) / 2:  # keep only the first half of the candidates
+                break
+            generated_ids = postprocess_seq(beam, tokenizer.bos_token_id,
+                                            tokenizer.eos_token_id)
+            seq = tokenizer.convert_ids_to_string(generated_ids)
+            print(f'{idx}-{beam_idx}: {seq}')
+
+
+if __name__ == "__main__":
+    args = setup_args()
+    pprint(args)
+
+    infer(args)
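For reference, `postprocess_seq` truncates at the first eos and strips the special
tokens; a toy check with made-up ids (bos=0, eos=2):

    # Everything after the first eos is dropped, as are bos/eos themselves.
    assert postprocess_seq([0, 11, 12, 2, 9, 9], bos_idx=0, eos_idx=2) == [11, 12]

The `transpose([1, 2, 0])` in `infer` implies the raw output tensor is laid out as
[seq_len, batch_size, num_candidates]; that layout is inferred from this sample rather
than from a documented contract.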
diff --git a/paddlenlp/ops/faster_transformer/transformer/decoding.py b/paddlenlp/ops/faster_transformer/transformer/decoding.py
index d4b9e74708f3..a5d4c8b04b82 100644
--- a/paddlenlp/ops/faster_transformer/transformer/decoding.py
+++ b/paddlenlp/ops/faster_transformer/transformer/decoding.py
@@ -1955,7 +1955,8 @@ def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
 
         self.pos_emb = [model.decoder.decoder_embed_positions.weight]
         self.word_emb = [model.decoder.embed_tokens.weight]
-        self.linear_weight = [model.lm_head_weight.t()]
+        setattr(self, "lm_head_weight_", model.lm_head_weight.t())  # register for dygraph-to-static
+        self.linear_weight = [getattr(self, "lm_head_weight_")]
         self.linear_bias = [model.final_logits_bias]
 
     def forward(self,
diff --git a/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py b/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py
index f4454603613f..dc5fdd8494f4 100644
--- a/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py
+++ b/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py
@@ -1172,8 +1172,13 @@ def forward(self,
 
 
 class FasterBART(BartPretrainedModel):
+    enable_faster_encoder_func = enable_faster_encoder
 
-    def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
+    def __init__(self,
+                 model,
+                 decoding_lib=None,
+                 use_fp16_decoding=False,
+                 enable_faster_encoder=True):
         super(FasterBART, self).__init__()
         self.use_fp16_decoding = use_fp16_decoding
         self._model = model
@@ -1186,10 +1191,14 @@ def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
         self.encoder = model.bart.get_encoder()
         self.decoder = model.bart.get_decoder()
         self.pad_token_id = model.bart.config['pad_token_id']
+        self.enable_faster_encoder = enable_faster_encoder
 
         self.decoding = InferBartDecoding(model=self._model,
                                           decoding_lib=decoding_lib,
                                           use_fp16_decoding=use_fp16_decoding)
+        if self.enable_faster_encoder:
+            # `enable_faster_encoder` must be called in `__init__` when converting dygraph to static graph.
+            self.encoder = FasterBART.enable_faster_encoder_func(self.encoder)
 
     def get_encoder(self):
         return self.encoder
@@ -1218,11 +1227,9 @@ def forward(self,
                 **model_kwargs):
 
         if encoder_output is None:
-            self.encoder = enable_faster_encoder(self.encoder)
             assert input_ids is not None, "You have to specify either input_ids or encoder_output."
             encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
                 input_ids, model_kwargs)["encoder_output"]
-            self.encoder = disable_faster_encoder(self.encoder)
         if seq_len is None:
             assert input_ids is not None, "You have to specify either input_ids when generating seq_len."
             seq_len = paddle.sum(paddle.cast(input_ids != self.pad_token_id,
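With the constructor change above, the faster encoder is swapped in once at construction
time instead of being toggled on and off around every forward call, which is what makes
`paddle.jit.to_static` tracing possible. A minimal dygraph sketch (model name and flag
values are illustrative):

    import paddle
    from paddlenlp.transformers import BartForConditionalGeneration
    from paddlenlp.ops import FasterBART

    model = BartForConditionalGeneration.from_pretrained("bart-base")
    model.eval()
    # enable_faster_encoder=True (the default) applies the faster encoder in __init__.
    faster_bart = FasterBART(model=model,
                             use_fp16_decoding=False,
                             enable_faster_encoder=True)
    faster_bart.eval()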
diff --git a/paddlenlp/transformers/bart/modeling.py b/paddlenlp/transformers/bart/modeling.py
index b2cb03f32c82..b874e7449adb 100644
--- a/paddlenlp/transformers/bart/modeling.py
+++ b/paddlenlp/transformers/bart/modeling.py
@@ -136,7 +136,8 @@ def forward(self, input_ids_shape, past_key_values_length=0):
         positions = paddle.arange(past_key_values_length,
                                   past_key_values_length + seq_len,
                                   dtype="int64")
-        return super().forward(positions + self.offset)
+        # (gongenlei) Call Embedding.forward directly (instead of super()) for dygraph to static graph
+        return Embedding.forward(self, positions + self.offset)
 
 
 class BartEncoder(BartPretrainedModel):
@@ -200,7 +201,7 @@ def forward(self, input_ids=None, attention_mask=None, **kwargs):
         if input_ids is None:
             raise ValueError("Input_ids cannot be None.")
         inputs_embeds = self.embed_tokens(input_ids)
-        inputs_embed_pos = self.encoder_embed_positions(input_ids.shape)
+        inputs_embed_pos = self.encoder_embed_positions(paddle.shape(input_ids))
         hidden_states = inputs_embeds + inputs_embed_pos
         hidden_states = self.encoder_layernorm_embedding(hidden_states)
         encoder_input = self.encoder_dropout(hidden_states)
@@ -298,7 +299,7 @@ def forward(self,
         past_key_values_length = paddle.shape(
             cache[0][0].k)[2] if cache is not None else 0
         decoder_inputs_embed_pos = self.decoder_embed_positions(
-            decoder_input_ids.shape, past_key_values_length)
+            paddle.shape(decoder_input_ids), past_key_values_length)
         hidden_states = decoder_inputs_embeds + decoder_inputs_embed_pos
         hidden_states = self.decoder_layernorm_embedding(hidden_states)
         decoder_input = self.decoder_dropout(hidden_states)
@@ -756,6 +757,8 @@ def prepare_faster_entry(self, kwargs):
             from paddlenlp.ops import FasterBART
             decode_strategy = kwargs.get('decode_strategy')
             use_fp16_decoding = kwargs.get('use_fp16_decoding', False)
+            decoding_lib = kwargs.get('decoding_lib', None)
+            enable_faster_encoder = kwargs.get('enable_faster_encoder', True)
             if decode_strategy == 'sampling' and kwargs.get(
                     'top_k') != 0 and kwargs.get('top_p') != 1:
                 raise AttributeError(
@@ -776,7 +779,10 @@ def prepare_faster_entry(self, kwargs):
                     "'forced_bos_token_id != None' is not supported yet in the faster version"
                 )
             self._faster_entry = FasterBART(
-                self, use_fp16_decoding=use_fp16_decoding).forward
+                self,
+                use_fp16_decoding=use_fp16_decoding,
+                decoding_lib=decoding_lib,
+                enable_faster_encoder=enable_faster_encoder).forward
             return self._faster_entry
 
     def forward(self,
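End to end, the new `decoding_lib` and `enable_faster_encoder` kwargs reach `FasterBART`
through `prepare_faster_entry`, so they can be passed straight to `generate`. A sketch
with illustrative values, assuming the `use_faster` switch of `generate` at this version
and `input_ids` prepared as in the decoding sample:

    output_ids, scores = model.generate(input_ids=input_ids,
                                        max_length=20,
                                        decode_strategy="beam_search",
                                        num_beams=5,
                                        use_faster=True,  # route through FasterBART
                                        enable_faster_encoder=True)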