[FasterBART] Support dygraph to static and inference (#2519)
* support d2s and inference

* default to beam_search

* add comments

* update tokenizer usage

* rm disable_faster_encoder
gongenlei authored Jun 17, 2022
1 parent 5901915 commit 222fc31
Showing 6 changed files with 283 additions and 17 deletions.
18 changes: 9 additions & 9 deletions paddlenlp/ops/faster_transformer/sample/bart_decoding_sample.py
@@ -38,11 +38,9 @@ def postprocess_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
return seq


-def prepare_input(tokenizer, sentences, pad_id):
-word_pad = Pad(pad_id, dtype="int64")
-tokenized = tokenizer(sentences, return_length=True)
-inputs = word_pad([i["input_ids"] for i in tokenized])
-input_ids = paddle.to_tensor(inputs)
+def prepare_input(tokenizer, sentences):
+tokenized = tokenizer(sentences, padding=True)
+input_ids = paddle.to_tensor(tokenized['input_ids'], dtype='int64')
return input_ids


@@ -57,13 +55,13 @@ def parse_args():
)
parser.add_argument(
"--decoding_strategy",
-default='sampling',
+default='beam_search',
type=str,
help=
"The decoding strategy. Can be one of [greedy_search, beam_search, sampling]"
)
parser.add_argument("--beam_size",
-default=4,
+default=5,
type=int,
help="The parameters for beam search. ")
parser.add_argument(
@@ -77,7 +75,7 @@ def parse_args():
type=float,
help="The probability threshold to procedure topp sampling. ")
parser.add_argument("--max_length",
-default=50,
+default=20,
type=int,
help="Maximum output length. ")
parser.add_argument("--diversity_rate",
@@ -103,6 +101,7 @@ def do_predict(args):
logger.info('Loading the model parameters, please wait...')
model = BartForConditionalGeneration.from_pretrained(
args.model_name_or_path)

# Set evaluate mode
model.eval()
sentences = [
@@ -115,7 +114,7 @@ def do_predict(args):
bos_id = model.bart.config['bos_token_id']
eos_id = model.bart.config['eos_token_id']
pad_id = model.bart.config['pad_token_id']
-input_ids = prepare_input(tokenizer, sentences, pad_id)
+input_ids = prepare_input(tokenizer, sentences)
# Define model
faster_bart = model

@@ -155,4 +154,5 @@ def do_predict(args):
if __name__ == "__main__":
args = parse_args()
pprint(args)

do_predict(args)
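A minimal standalone sketch of the updated tokenizer usage (the model name and sentences here are assumptions for illustration; it mirrors the new `prepare_input` above, where the tokenizer's built-in padding replaces the manual `Pad`-based batching):

```python
import paddle
from paddlenlp.transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("bart-base")
sentences = [
    "I love that girl, but <mask> does not <mask> me.",
    "Nothing's gonna <mask> my love for you.",
]

# padding=True pads the whole batch in one call, so the old
# `Pad(pad_id, dtype="int64")` step is no longer needed.
tokenized = tokenizer(sentences, padding=True)
input_ids = paddle.to_tensor(tokenized["input_ids"], dtype="int64")
print(input_ids.shape)  # [batch_size, padded_seq_len]
```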
140 changes: 140 additions & 0 deletions paddlenlp/ops/faster_transformer/sample/bart_export_model_sample.py
@@ -0,0 +1,140 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
import paddle
from pprint import pprint
from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer
from paddlenlp.ops import FasterBART
from paddlenlp.utils.log import logger


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path",
default="bart-base",
type=str,
help="The model name to specify the bart to use. ")
parser.add_argument("--inference_model_dir",
default="./infer_model/",
type=str,
help="Path to save inference model of bart. ")
parser.add_argument(
"--topk",
default=4,
type=int,
help="The number of candidate to procedure top_k sampling. ")
parser.add_argument(
"--topp",
default=1.0,
type=float,
help="The probability threshold to procedure top_p sampling. ")
parser.add_argument("--max_out_len",
default=20,
type=int,
help="Maximum output length. ")
parser.add_argument("--temperature",
default=1.0,
type=float,
help="The temperature to set. ")
parser.add_argument("--num_return_sequences",
default=1,
type=int,
help="The number of returned sequences. ")
parser.add_argument("--use_fp16_decoding",
action="store_true",
help="Whether to use fp16 decoding to predict. ")
parser.add_argument("--decoding_strategy",
default="beam_search",
choices=["sampling", "beam_search"],
type=str,
help="The main strategy to decode. ")
parser.add_argument(
"--num_beams",
default=5,
type=int,
help="The number of candidate to procedure beam search. ")
parser.add_argument("--diversity_rate",
default=0.0,
type=float,
help="The diversity rate to procedure beam search. ")
parser.add_argument("--repetition_penalty",
default=1.0,
type=float,
help="The repetition_penalty to set. ")
parser.add_argument("--length_penalty",
default=0.0,
type=float,
help="The length penalty to decode. ")
parser.add_argument("--early_stopping",
action="store_true",
help="Whether to do early stopping. ")

args = parser.parse_args()
return args


def do_predict(args):
place = "gpu"
place = paddle.set_device(place)

model = BartForConditionalGeneration.from_pretrained(
args.model_name_or_path)
tokenizer = BartTokenizer.from_pretrained(args.model_name_or_path)

# Eval mode is required before the faster encoder can be enabled.
model.eval()

faster_bart = FasterBART(model=model,
use_fp16_decoding=args.use_fp16_decoding)
# Set evaluate mode
faster_bart.eval()

# Convert dygraph model to static graph model
faster_bart = paddle.jit.to_static(
faster_bart,
input_spec=[
# input_ids
paddle.static.InputSpec(shape=[None, None], dtype="int32"),
# encoder_output
None,
# seq_len
None,
args.num_beams, # num_beams.
args.topk,
args.topp,
args.decoding_strategy,
tokenizer.bos_token_id, # bos
tokenizer.eos_token_id, # eos
tokenizer.pad_token_id, # pad
tokenizer.eos_token_id, # decoder_start_token_id
args.max_out_len, # max_length
args.diversity_rate, # diversity_rate
args.length_penalty, # length_penalty
args.num_return_sequences,
args.early_stopping,
tokenizer.eos_token_id,  # forced_eos_token_id
])

# Save converted static graph model
paddle.jit.save(faster_bart, os.path.join(args.inference_model_dir, "bart"))
logger.info("BART has been saved to {}.".format(args.inference_model_dir))


if __name__ == "__main__":
args = parse_args()
pprint(args)

do_predict(args)
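After exporting, one way to sanity-check the saved artifacts is to reload them as a static-graph Layer in dygraph mode. This is only a sketch under assumptions (default `--inference_model_dir`, toy input ids); the full deployment path is the `bart_inference.py` sample below:

```python
import paddle
from paddlenlp.ops.ext_utils import load

# The FasterTransformer custom ops must be loaded before the exported
# graph can be executed (the inference sample below does the same).
load("FasterTransformer", verbose=True)

# paddle.jit.save(...) above writes bart.pdmodel / bart.pdiparams
# under the chosen --inference_model_dir.
loaded_bart = paddle.jit.load("./infer_model/bart")
loaded_bart.eval()

# Toy input ids just to confirm the exported graph runs end to end.
input_ids = paddle.to_tensor([[0, 31414, 232, 2]], dtype="int32")
outputs = loaded_bart(input_ids)
```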
112 changes: 112 additions & 0 deletions paddlenlp/ops/faster_transformer/sample/bart_inference.py
@@ -0,0 +1,112 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import numpy as np
from pprint import pprint

import paddle
import paddle.inference as paddle_infer

from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer
from paddlenlp.ops.ext_utils import load


def setup_args():
"""Setup arguments."""
parser = argparse.ArgumentParser()
parser.add_argument("--inference_model_dir",
default="./infer_model/",
type=str,
help="Path to save inference model of BART. ")

args = parser.parse_args()

return args


def prepare_input(tokenizer, sentences):
tokenized = tokenizer(sentences, padding=True)
input_ids = np.asarray(tokenized['input_ids'], dtype="int32")
return input_ids


def postprocess_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
"""
Post-process the decoded sequence.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = [
idx for idx in seq[:eos_pos + 1]
if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
]
return seq


def infer(args):
model_name = 'bart-base'
tokenizer = BartTokenizer.from_pretrained(model_name)

sentences = [
"I love that girl, but <mask> does not <mask> me.",
"She is so <mask> that I can not help glance at <mask>.",
"Nothing's gonna <mask> my love for you.",
"Drop everything now. Meet me in the pouring <mask>. Kiss me on the sidewalk.",
]

input_ids = prepare_input(tokenizer, sentences)

# Load FasterTransformer lib.
load("FasterTransformer", verbose=True)

config = paddle_infer.Config(
os.path.join(args.inference_model_dir, "bart.pdmodel"),
os.path.join(args.inference_model_dir, "bart.pdiparams"))

config.enable_use_gpu(100, 0)
config.disable_glog_info()
# `embedding_eltwise_layernorm_fuse_pass` failed
config.delete_pass("embedding_eltwise_layernorm_fuse_pass")
predictor = paddle_infer.create_predictor(config)

input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
input_handle.copy_from_cpu(input_ids.astype("int32"))

predictor.run()

output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
output_data = output_handle.copy_to_cpu()
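# The predictor output is laid out as [max_out_len, batch_size, num_beams];
# the transpose below rearranges it to [batch, beam, seq_len] so that each
# row is one candidate sequence for one input sentence.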

for idx, sample in enumerate(output_data.transpose([1, 2, 0]).tolist()):
for beam_idx, beam in enumerate(sample):
if beam_idx >= len(sample) / 2:
break
generated_ids = postprocess_seq(beam, tokenizer.bos_token_id,
tokenizer.eos_token_id)
seq = tokenizer.convert_ids_to_string(generated_ids)
print(f'{idx}-{beam_idx}: {seq}')


if __name__ == "__main__":
args = setup_args()
pprint(args)

infer(args)
3 changes: 2 additions & 1 deletion paddlenlp/ops/faster_transformer/transformer/decoding.py
@@ -1955,7 +1955,8 @@ def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
self.pos_emb = [model.decoder.decoder_embed_positions.weight]
self.word_emb = [model.decoder.embed_tokens.weight]

-self.linear_weight = [model.lm_head_weight.t()]
+setattr(self, "lm_head_weight_", model.lm_head_weight.t())
+self.linear_weight = [getattr(self, "lm_head_weight_")]
self.linear_bias = [model.final_logits_bias]

def forward(self,
13 changes: 10 additions & 3 deletions paddlenlp/ops/faster_transformer/transformer/faster_transformer.py
@@ -1172,8 +1172,13 @@ def forward(self,


class FasterBART(BartPretrainedModel):
+enable_faster_encoder_func = enable_faster_encoder
+
-def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
+def __init__(self,
+model,
+decoding_lib=None,
+use_fp16_decoding=False,
+enable_faster_encoder=True):
super(FasterBART, self).__init__()
self.use_fp16_decoding = use_fp16_decoding
self._model = model
@@ -1186,10 +1191,14 @@ def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
self.encoder = model.bart.get_encoder()
self.decoder = model.bart.get_decoder()
self.pad_token_id = model.bart.config['pad_token_id']
+self.enable_faster_encoder = enable_faster_encoder

self.decoding = InferBartDecoding(model=self._model,
decoding_lib=decoding_lib,
use_fp16_decoding=use_fp16_decoding)
+if self.enable_faster_encoder:
+# `enable_faster_encoder` must be applied in `__init__` so the swap is already in place for dygraph-to-static conversion.
+self.encoder = FasterBART.enable_faster_encoder_func(self.encoder)

def get_encoder(self):
return self.encoder
@@ -1218,11 +1227,9 @@ def forward(self,
**model_kwargs):

if encoder_output is None:
-self.encoder = enable_faster_encoder(self.encoder)
assert input_ids is not None, "You have to specify either input_ids or encoder_output."
encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
input_ids, model_kwargs)["encoder_output"]
-self.encoder = disable_faster_encoder(self.encoder)
if seq_len is None:
assert input_ids is not None, "You have to specify either input_ids when generating seq_len."
seq_len = paddle.sum(paddle.cast(input_ids != self.pad_token_id,
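The constructor change above is what makes the export sample work: the encoder is swapped to its FasterTransformer implementation once, inside `__init__`, instead of being enabled and disabled around every `forward` call, so the swap is already in place when `paddle.jit.to_static` traces the graph. A simplified export sketch (decoding hyper-parameters are left at their defaults here, which is an assumption; the full argument list is in `bart_export_model_sample.py` above):

```python
import paddle
from paddlenlp.transformers import BartForConditionalGeneration
from paddlenlp.ops import FasterBART

model = BartForConditionalGeneration.from_pretrained("bart-base")
model.eval()

# enable_faster_encoder=True (the new default) swaps the encoder inside
# __init__, before any tracing happens.
faster_bart = FasterBART(model=model,
                         use_fp16_decoding=False,
                         enable_faster_encoder=True)
faster_bart.eval()

# Only input_ids is declared dynamic; the remaining forward() arguments
# keep their Python defaults during tracing.
static_bart = paddle.jit.to_static(
    faster_bart,
    input_spec=[paddle.static.InputSpec(shape=[None, None], dtype="int32")])
paddle.jit.save(static_bart, "./infer_model/bart")
```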
