[FasterBART] Support dygraph to static and inference #2519

Merged 9 commits on Jun 17, 2022
Changes from 5 commits
18 changes: 9 additions & 9 deletions paddlenlp/ops/faster_transformer/sample/bart_decoding_sample.py
@@ -38,11 +38,9 @@ def postprocess_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
    return seq


-def prepare_input(tokenizer, sentences, pad_id):
-    word_pad = Pad(pad_id, dtype="int64")
-    tokenized = tokenizer(sentences, return_length=True)
-    inputs = word_pad([i["input_ids"] for i in tokenized])
-    input_ids = paddle.to_tensor(inputs)
+def prepare_input(tokenizer, sentences):
+    tokenized = tokenizer(sentences, padding=True)
+    input_ids = paddle.to_tensor(tokenized['input_ids'], dtype='int64')
     return input_ids


@@ -57,13 +55,13 @@ def parse_args():
    )
    parser.add_argument(
        "--decoding_strategy",
-        default='sampling',
+        default='beam_search',
        type=str,
        help=
        "The decoding strategy. Can be one of [greedy_search, beam_search, sampling]"
    )
    parser.add_argument("--beam_size",
-                        default=4,
+                        default=5,
                        type=int,
                        help="The parameters for beam search. ")
parser.add_argument(
@@ -77,7 +75,7 @@ def parse_args():
type=float,
help="The probability threshold to procedure topp sampling. ")
parser.add_argument("--max_length",
-                        default=50,
+                        default=20,
type=int,
help="Maximum output length. ")
parser.add_argument("--diversity_rate",
@@ -103,6 +101,7 @@ def do_predict(args):
logger.info('Loading the model parameters, please wait...')
model = BartForConditionalGeneration.from_pretrained(
args.model_name_or_path)

# Set evaluate mode
model.eval()
sentences = [
@@ -115,7 +114,7 @@ def do_predict(args):
bos_id = model.bart.config['bos_token_id']
eos_id = model.bart.config['eos_token_id']
pad_id = model.bart.config['pad_token_id']
-    input_ids = prepare_input(tokenizer, sentences, pad_id)
+    input_ids = prepare_input(tokenizer, sentences)
# Define model
faster_bart = model

@@ -155,4 +154,5 @@ def do_predict(args):
if __name__ == "__main__":
args = parse_args()
pprint(args)

do_predict(args)
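The change above replaces the manual `Pad`-based collation with the tokenizer's built-in batch padding, so `prepare_input` no longer needs a `pad_id` argument. A minimal sketch of what the new code path does (the sentences and the final print are illustrative, not from the PR):

import paddle
from paddlenlp.transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("bart-base")
sentences = ["I love that girl.", "Nothing's gonna change my love for you."]

# `padding=True` pads every sample to the longest sequence in the batch,
# so no separate `Pad(pad_id)` collator is needed.
tokenized = tokenizer(sentences, padding=True)
input_ids = paddle.to_tensor(tokenized["input_ids"], dtype="int64")
print(input_ids.shape)  # [2, longest_sequence_length_in_batch]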
140 changes: 140 additions & 0 deletions paddlenlp/ops/faster_transformer/sample/bart_export_model_sample.py
@@ -0,0 +1,140 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
import paddle
from pprint import pprint
from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer
from paddlenlp.ops import FasterBART
from paddlenlp.utils.log import logger


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path",
default="bart-base",
type=str,
help="The model name to specify the bart to use. ")
parser.add_argument("--inference_model_dir",
default="./infer_model/",
type=str,
help="Path to save inference model of bart. ")
parser.add_argument(
"--topk",
default=4,
type=int,
help="The number of candidate to procedure top_k sampling. ")
parser.add_argument(
"--topp",
default=1.0,
type=float,
help="The probability threshold to procedure top_p sampling. ")
parser.add_argument("--max_out_len",
default=20,
type=int,
help="Maximum output length. ")
parser.add_argument("--temperature",
default=1.0,
type=float,
help="The temperature to set. ")
parser.add_argument("--num_return_sequences",
default=1,
type=int,
help="The number of returned sequences. ")
parser.add_argument("--use_fp16_decoding",
action="store_true",
help="Whether to use fp16 decoding to predict. ")
parser.add_argument("--decoding_strategy",
default="beam_search",
choices=["sampling", "beam_search"],
type=str,
help="The main strategy to decode. ")
parser.add_argument(
"--num_beams",
default=5,
type=int,
help="The number of candidate to procedure beam search. ")
parser.add_argument("--diversity_rate",
default=0.0,
type=float,
help="The diversity rate to procedure beam search. ")
parser.add_argument("--repetition_penalty",
default=1.0,
type=float,
help="The repetition_penalty to set. ")
parser.add_argument("--length_penalty",
default=0.0,
type=float,
help="The length penalty to decode. ")
parser.add_argument("--early_stopping",
action="store_true",
help="Whether to do early stopping. ")

args = parser.parse_args()
return args


def do_predict(args):
place = "gpu"
place = paddle.set_device(place)

model = BartForConditionalGeneration.from_pretrained(
args.model_name_or_path)
tokenizer = BartTokenizer.from_pretrained(args.model_name_or_path)

# Set eval mode before enabling the faster encoder
model.eval()

faster_bart = FasterBART(model=model,
use_fp16_decoding=args.use_fp16_decoding)
# Set evaluate mode
faster_bart.eval()

# Convert dygraph model to static graph model
faster_bart = paddle.jit.to_static(
faster_bart,
input_spec=[
# input_ids
paddle.static.InputSpec(shape=[None, None], dtype="int32"),
# encoder_output
None,
# seq_len
None,
args.num_beams, # num_beams.
args.topk,
args.topp,
args.decoding_strategy,
tokenizer.bos_token_id, # bos
tokenizer.eos_token_id, # eos
tokenizer.pad_token_id, # pad
tokenizer.eos_token_id, # decoder_start_token_id
args.max_out_len, # max_length
args.diversity_rate, # diversity_rate
args.length_penalty, # length_penalty
args.num_return_sequences,
args.early_stopping,
tokenizer.eos_token_id, #forced_eos_token_id
])

# Save converted static graph model
paddle.jit.save(faster_bart, os.path.join(args.inference_model_dir, "bart"))
logger.info("BART has been saved to {}.".format(args.inference_model_dir))


if __name__ == "__main__":
args = parse_args()
pprint(args)

do_predict(args)
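One behavior of `paddle.jit.to_static` that the `input_spec` above leans on: entries given as `paddle.static.InputSpec` remain real inputs of the static graph with dynamic dimensions, while plain Python values (and `None`) are captured as constants at trace time. A minimal, self-contained sketch of that pattern, assuming the same Paddle dygraph-to-static behavior the export script relies on (the `Scale` layer and save path are illustrative, not part of the PR):

import os
import paddle


class Scale(paddle.nn.Layer):

    def forward(self, x, factor):
        return x * factor


layer = Scale()
layer.eval()
static_layer = paddle.jit.to_static(
    layer,
    input_spec=[
        # Stays a real graph input with dynamic dims, like `input_ids` above.
        paddle.static.InputSpec(shape=[None, None], dtype="float32"),
        # A plain Python value is captured as a constant at trace time,
        # like `num_beams`, `topk`, `topp`, and the token ids above.
        2.0,
    ])
paddle.jit.save(static_layer, os.path.join("./scale_model", "scale"))

This is why the saved BART model exposes only `input_ids` as a runtime input: the decoding hyperparameters are frozen into the graph and cannot be changed at inference time.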
112 changes: 112 additions & 0 deletions paddlenlp/ops/faster_transformer/sample/bart_inference.py
@@ -0,0 +1,112 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import numpy as np
from pprint import pprint

import paddle
import paddle.inference as paddle_infer

from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer
from paddlenlp.ops.ext_utils import load


def setup_args():
"""Setup arguments."""
parser = argparse.ArgumentParser()
parser.add_argument("--inference_model_dir",
default="./infer_model/",
type=str,
help="Path to save inference model of BART. ")

args = parser.parse_args()

return args


def prepare_input(tokenizer, sentences):
tokenized = tokenizer(sentences, padding=True)
input_ids = np.asarray(tokenized['input_ids'], dtype="int32")
return input_ids


def postprocess_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
"""
Post-process the decoded sequence.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = [
idx for idx in seq[:eos_pos + 1]
if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
]
return seq


def infer(args):
model_name = 'bart-base'
tokenizer = BartTokenizer.from_pretrained(model_name)

sentences = [
"I love that girl, but <mask> does not <mask> me.",
"She is so <mask> that I can not help glance at <mask>.",
"Nothing's gonna <mask> my love for you.",
"Drop everything now. Meet me in the pouring <mask>. Kiss me on the sidewalk.",
]

input_ids = prepare_input(tokenizer, sentences)

# Load FasterTransformer lib.
load("FasterTransformer", verbose=True)

config = paddle_infer.Config(
os.path.join(args.inference_model_dir, "bart.pdmodel"),
os.path.join(args.inference_model_dir, "bart.pdiparams"))

config.enable_use_gpu(100, 0)
config.disable_glog_info()
# `embedding_eltwise_layernorm_fuse_pass` failed
config.delete_pass("embedding_eltwise_layernorm_fuse_pass")
predictor = paddle_infer.create_predictor(config)

input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
input_handle.copy_from_cpu(input_ids.astype("int32"))

predictor.run()

output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
output_data = output_handle.copy_to_cpu()

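    # The decoding op's output is laid out as [max_out_len, batch_size,
    # num_beams]; transpose to [batch_size, num_beams, max_out_len] so each
    # row holds one sample's beams. Only the first half of the beams is
    # decoded below, since the op returns 2 * num_beams candidates per sample.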
for idx, sample in enumerate(output_data.transpose([1, 2, 0]).tolist()):
for beam_idx, beam in enumerate(sample):
if beam_idx >= len(sample) / 2:
break
generated_ids = postprocess_seq(beam, tokenizer.bos_token_id,
tokenizer.eos_token_id)
seq = tokenizer.convert_ids_to_string(generated_ids)
print(f'{idx}-{beam_idx}: {seq}')


if __name__ == "__main__":
args = setup_args()
pprint(args)

infer(args)
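Assuming the export script above has written `bart.pdmodel` and `bart.pdiparams` into `./infer_model/` (its default `--inference_model_dir`), this script runs as `python bart_inference.py --inference_model_dir ./infer_model/`. Note that decoding hyperparameters are not arguments here: they were fixed at export time.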
3 changes: 2 additions & 1 deletion paddlenlp/ops/faster_transformer/transformer/decoding.py
@@ -1955,7 +1955,8 @@ def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
self.pos_emb = [model.decoder.decoder_embed_positions.weight]
self.word_emb = [model.decoder.embed_tokens.weight]

-        self.linear_weight = [model.lm_head_weight.t()]
+        setattr(self, "lm_head_weight_", model.lm_head_weight.t())
+        self.linear_weight = [getattr(self, "lm_head_weight_")]
self.linear_bias = [model.final_logits_bias]

def forward(self,
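A plausible rationale for the indirection above (an inference from the PR's dygraph-to-static goal, not stated in the diff): `model.lm_head_weight.t()` produces a new tensor, and binding it to the layer under a stable attribute name keeps it alive as part of the layer's state where `paddle.jit.to_static` can capture it, whereas a temporary referenced only from a Python list may not survive into the saved program. The pattern in isolation:

import paddle


class TinyHead(paddle.nn.Layer):
    """Illustrative only -- mirrors the setattr pattern used above."""

    def __init__(self, lm_head_weight):
        super().__init__()
        # Register the transposed weight as a named attribute so tracing
        # treats it as layer state rather than a throwaway temporary.
        setattr(self, "lm_head_weight_", lm_head_weight.t())
        self.linear_weight = [getattr(self, "lm_head_weight_")]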
16 changes: 13 additions & 3 deletions paddlenlp/ops/faster_transformer/transformer/faster_transformer.py
@@ -1172,8 +1172,13 @@ def forward(self,


class FasterBART(BartPretrainedModel):
+    enable_faster_encoder_func = enable_faster_encoder

-    def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
+    def __init__(self,
+                 model,
+                 decoding_lib=None,
+                 use_fp16_decoding=False,
+                 enable_faster_encoder=True):
super(FasterBART, self).__init__()
self.use_fp16_decoding = use_fp16_decoding
self._model = model
@@ -1186,10 +1191,15 @@ def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
self.encoder = model.bart.get_encoder()
self.decoder = model.bart.get_decoder()
        self.pad_token_id = model.bart.config['pad_token_id']
+        self.enable_faster_encoder = enable_faster_encoder

        self.decoding = InferBartDecoding(model=self._model,
                                          decoding_lib=decoding_lib,
                                          use_fp16_decoding=use_fp16_decoding)
+        if self.enable_faster_encoder:
+            # (gongenlei) `enable_faster_encoder` must run in `__init__`
+            # when converting dygraph to static graph.
+            self.encoder = FasterBART.enable_faster_encoder_func(self.encoder)

def get_encoder(self):
return self.encoder
@@ -1218,11 +1228,11 @@ def forward(self,
**model_kwargs):

        if encoder_output is None:
-            self.encoder = enable_faster_encoder(self.encoder)
            assert input_ids is not None, "You have to specify either input_ids or encoder_output."
            encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
                input_ids, model_kwargs)["encoder_output"]
-            self.encoder = disable_faster_encoder(self.encoder)
+            if self.enable_faster_encoder:
+                self.encoder = disable_faster_encoder(self.encoder)
        if seq_len is None:
            assert input_ids is not None, "You have to specify input_ids when generating seq_len."
            seq_len = paddle.sum(paddle.cast(input_ids != self.pad_token_id,
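Taken together, the `FasterBART` changes move the encoder swap out of `forward` and into `__init__`, where `paddle.jit.to_static` can trace the already-swapped module; `forward` now only disables the faster encoder again after encoding when the flag is set. A hedged usage sketch (assuming the FasterTransformer custom ops can be loaded or built, as in the export script above):

import paddle
from paddlenlp.transformers import BartForConditionalGeneration
from paddlenlp.ops import FasterBART

model = BartForConditionalGeneration.from_pretrained("bart-base")
model.eval()

# With the new enable_faster_encoder=True default, the encoder is replaced
# by its faster implementation inside FasterBART.__init__, so a later
# paddle.jit.to_static call traces the already-swapped module.
faster_bart = FasterBART(model=model, use_fp16_decoding=False)
faster_bart.eval()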