[FasterBART] Support dygraph to static and inference (#2519)

* support d2s and inference
* default to beam_search
* add comments
* update tokenizer usage
* rm disable_faster_encoder
gongenlei authored Jun 17, 2022 · 1 parent 5901915 · commit 222fc31
Showing 6 changed files with 283 additions and 17 deletions.
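The two new sample scripts shown below cover the full workflow: bart_export_model_sample.py converts a dygraph FasterBART model to a static graph and saves it, and bart_inference.py loads the saved model with the Paddle Inference API and decodes a few masked sentences. Running `python bart_export_model_sample.py` followed by `python bart_inference.py` (both under paddlenlp/ops/faster_transformer/sample/, with the default `--inference_model_dir ./infer_model/`) should reproduce the end-to-end flow, assuming a CUDA-enabled Paddle build, since both scripts pin the device to GPU.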
paddlenlp/ops/faster_transformer/sample/bart_export_model_sample.py (140 additions, 0 deletions)
@@ -0,0 +1,140 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
import paddle
from pprint import pprint
from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer
from paddlenlp.ops import FasterBART
from paddlenlp.utils.log import logger


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path",
                        default="bart-base",
                        type=str,
                        help="The model name of the BART variant to use.")
    parser.add_argument("--inference_model_dir",
                        default="./infer_model/",
                        type=str,
                        help="Path to save the BART inference model.")
    parser.add_argument("--topk",
                        default=4,
                        type=int,
                        help="The number of candidates for top-k sampling.")
    parser.add_argument("--topp",
                        default=1.0,
                        type=float,
                        help="The probability threshold for top-p sampling.")
    parser.add_argument("--max_out_len",
                        default=20,
                        type=int,
                        help="Maximum output length.")
    parser.add_argument("--temperature",
                        default=1.0,
                        type=float,
                        help="The temperature for sampling.")
    parser.add_argument("--num_return_sequences",
                        default=1,
                        type=int,
                        help="The number of returned sequences.")
    parser.add_argument("--use_fp16_decoding",
                        action="store_true",
                        help="Whether to use fp16 decoding to predict.")
    parser.add_argument("--decoding_strategy",
                        default="beam_search",
                        choices=["sampling", "beam_search"],
                        type=str,
                        help="The main strategy to decode.")
    parser.add_argument("--num_beams",
                        default=5,
                        type=int,
                        help="The number of beams for beam search.")
    parser.add_argument("--diversity_rate",
                        default=0.0,
                        type=float,
                        help="The diversity rate for beam search.")
    parser.add_argument("--repetition_penalty",
                        default=1.0,
                        type=float,
                        help="The repetition penalty to set.")
    parser.add_argument("--length_penalty",
                        default=0.0,
                        type=float,
                        help="The length penalty to decode.")
    parser.add_argument("--early_stopping",
                        action="store_true",
                        help="Whether to do early stopping.")

    args = parser.parse_args()
    return args


def do_predict(args):
    place = "gpu"
    place = paddle.set_device(place)

    model = BartForConditionalGeneration.from_pretrained(
        args.model_name_or_path)
    tokenizer = BartTokenizer.from_pretrained(args.model_name_or_path)

    # Evaluation mode is required to enable the faster encoder.
    model.eval()

    faster_bart = FasterBART(model=model,
                             use_fp16_decoding=args.use_fp16_decoding)
    # Set evaluation mode.
    faster_bart.eval()

    # Convert the dygraph model to a static graph model.
    faster_bart = paddle.jit.to_static(
        faster_bart,
        input_spec=[
            # input_ids
            paddle.static.InputSpec(shape=[None, None], dtype="int32"),
            # encoder_output
            None,
            # seq_len
            None,
            args.num_beams,  # num_beams
            args.topk,
            args.topp,
            args.decoding_strategy,
            tokenizer.bos_token_id,  # bos
            tokenizer.eos_token_id,  # eos
            tokenizer.pad_token_id,  # pad
            tokenizer.eos_token_id,  # decoder_start_token_id
            args.max_out_len,  # max_length
            args.diversity_rate,  # diversity_rate
            args.length_penalty,  # length_penalty
            args.num_return_sequences,
            args.early_stopping,
            tokenizer.eos_token_id,  # forced_eos_token_id
        ])

    # Save the converted static graph model.
    paddle.jit.save(faster_bart, os.path.join(args.inference_model_dir,
                                              "bart"))
    logger.info("BART has been saved to {}.".format(args.inference_model_dir))


if __name__ == "__main__":
    args = parse_args()
    pprint(args)

    do_predict(args)
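Note that only input_ids is declared as an InputSpec with dynamic batch and sequence dimensions; the remaining entries (beam count, sampling parameters, special-token ids) are plain Python values, so they are baked into the exported graph as constants, and changing any of them means re-running the export. A minimal sanity check of the export, as a sketch: it assumes the default ./infer_model/ output directory, and the file names follow from the "bart" prefix given to paddle.jit.save.

import os

# Sketch: confirm the static graph artifacts that bart_inference.py expects.
for fname in ("bart.pdmodel", "bart.pdiparams"):
    path = os.path.join("./infer_model/", fname)
    assert os.path.exists(path), "missing exported file: " + path
print("Export artifacts found.")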
paddlenlp/ops/faster_transformer/sample/bart_inference.py (112 additions, 0 deletions)
@@ -0,0 +1,112 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import numpy as np
from pprint import pprint

import paddle
import paddle.inference as paddle_infer

from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer
from paddlenlp.ops.ext_utils import load


def setup_args():
    """Set up arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--inference_model_dir",
                        default="./infer_model/",
                        type=str,
                        help="Path of the saved BART inference model.")

    args = parser.parse_args()

    return args


def prepare_input(tokenizer, sentences):
    tokenized = tokenizer(sentences, padding=True)
    input_ids = np.asarray(tokenized['input_ids'], dtype="int32")
    return input_ids


def postprocess_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
    """
    Post-process the decoded sequence: truncate at the first <eos> and
    optionally drop the <bos> and <eos> tokens.
    """
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    seq = [
        idx for idx in seq[:eos_pos + 1]
        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
    ]
    return seq


def infer(args):
    model_name = 'bart-base'
    tokenizer = BartTokenizer.from_pretrained(model_name)

    sentences = [
        "I love that girl, but <mask> does not <mask> me.",
        "She is so <mask> that I can not help glance at <mask>.",
        "Nothing's gonna <mask> my love for you.",
        "Drop everything now. Meet me in the pouring <mask>. Kiss me on the sidewalk.",
    ]

    input_ids = prepare_input(tokenizer, sentences)

    # Load the FasterTransformer custom-op library.
    load("FasterTransformer", verbose=True)

    config = paddle_infer.Config(
        os.path.join(args.inference_model_dir, "bart.pdmodel"),
        os.path.join(args.inference_model_dir, "bart.pdiparams"))

    config.enable_use_gpu(100, 0)
    config.disable_glog_info()
    # The `embedding_eltwise_layernorm_fuse_pass` fails on this model, so
    # remove it from the optimization passes.
    config.delete_pass("embedding_eltwise_layernorm_fuse_pass")
    predictor = paddle_infer.create_predictor(config)

    input_names = predictor.get_input_names()
    input_handle = predictor.get_input_handle(input_names[0])
    input_handle.copy_from_cpu(input_ids.astype("int32"))

    predictor.run()

    output_names = predictor.get_output_names()
    output_handle = predictor.get_output_handle(output_names[0])
    output_data = output_handle.copy_to_cpu()

    for idx, sample in enumerate(output_data.transpose([1, 2, 0]).tolist()):
        for beam_idx, beam in enumerate(sample):
            if beam_idx >= len(sample) / 2:
                break
            generated_ids = postprocess_seq(beam, tokenizer.bos_token_id,
                                            tokenizer.eos_token_id)
            seq = tokenizer.convert_ids_to_string(generated_ids)
            print(f'{idx}-{beam_idx}: {seq}')


if __name__ == "__main__":
    args = setup_args()
    pprint(args)

    infer(args)
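A note on the output loop above: the transpose([1, 2, 0]) suggests the predictor returns ids laid out as (max_out_len, batch_size, num_candidates), which the loop rearranges to iterate samples, then candidates, then time steps. The `beam_idx >= len(sample) / 2` cutoff appears to account for FasterTransformer's beam search returning twice num_beams candidates, keeping only the first half; that is an inference from the code, not something stated in the commit. postprocess_seq itself is easy to check in isolation; a tiny illustration using the bart-base special-token ids (bos_token_id 0, eos_token_id 2, pad_token_id 1), with made-up subword ids standing in for real tokens:

# Hypothetical token ids: 0 = <s>, 2 = </s>, 1 = <pad>; 31414 and 232 stand in
# for ordinary subword ids.
demo = [0, 31414, 232, 2, 1, 1]
print(postprocess_seq(demo, bos_idx=0, eos_idx=2))  # -> [31414, 232]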
(The remaining four changed files in this commit are not shown here.)