Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lift restrictions to support more d_inner_hid #3592

Merged
merged 5 commits into from
Oct 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import numpy as np
Expand Down Expand Up @@ -178,6 +192,11 @@ def do_predict(args):
transformer.load(init_from_params=os.path.join(args.init_from_params,
"transformer.pdparams"))

# Providing model_dict still works.
# state_dict = paddle.load(os.path.join(args.init_from_params,
# "transformer.pdparams"))
# transformer.load(state_dict=state_dict)

f = open(args.output_file, "w")
with paddle.no_grad():
if args.profile:
Expand Down
5 changes: 5 additions & 0 deletions examples/machine_translation/transformer/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ def do_predict(args):
transformer.load(os.path.join(args.init_from_params,
"transformer.pdparams"))

# Providing model_dict still works.
# state_dict = paddle.load(os.path.join(args.init_from_params,
# "transformer.pdparams"))
# transformer.load(state_dict=state_dict)

# Set evaluate mode
transformer.eval()

Expand Down
73 changes: 54 additions & 19 deletions paddlenlp/ops/faster_transformer/src/fusion_decoding_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include <cuda_runtime.h>
#include <curand.h>
#include <curand_kernel.h>

#include <algorithm>
#include <iterator>
#include <random>
Expand Down Expand Up @@ -124,6 +125,8 @@ std::vector<paddle::Tensor> decoding_kernel(
DecoderInitParam<DataType_>* params =
new DecoderInitParam<DataType_>[num_layer_];

int inner_coeff = ffn_intermediate_weight[0].shape()[1] / memory_hidden_dim;

auto q_weight_shape = self_attn_query_weight[0].shape();
auto k_weight_shape = self_attn_key_weight[0].shape();
bool fuse_qkv = (q_weight_shape[1] == k_weight_shape[1]) ? false : true;
Expand Down Expand Up @@ -265,7 +268,19 @@ std::vector<paddle::Tensor> decoding_kernel(
end_id_,
beam_search_diversity_rate_,
true, // is_fuse_topk_softMax
fuse_qkv);
fuse_qkv,
false, // keep_alive_beam
0.6, // alpha
true, // normalization_before
0, // pos_offset
ActivationType::RELU, // act
false, // pos_bias
false, // prefix_lm
-1, // finished_candidate_num
false, // early_stopping
false, // is_mbart
0, // min_length
inner_coeff);

decoding_beam_search_->forward(params, decoding_params);

Expand All @@ -286,10 +301,20 @@ std::vector<paddle::Tensor> decoding_kernel(
start_id_,
end_id_,
beam_search_diversity_rate_,
true, // is_fuse_topk_softMax
true, // is_fuse_topk_softMax
fuse_qkv,
true, // keep_alive_beam
alpha);
true, // keep_alive_beam
alpha,
true, // normalization_before
0, // pos_offset
ActivationType::RELU, // act
false, // pos_bias
false, // prefix_lm
-1, // finished_candidate_num
false, // early_stopping
false, // is_mbart
0, // min_length
inner_coeff);

decoding_beam_search_->forward(params, decoding_params);

Expand All @@ -298,21 +323,31 @@ std::vector<paddle::Tensor> decoding_kernel(
"topp_sampling" == decoding_strategy ||
"sampling" == decoding_strategy) {
DecodingSampling<DecodingTraits_::OpType>* decoding_sampling_;
decoding_sampling_ =
new DecodingSampling<DecodingTraits_::OpType>(allocator_,
batch_size_,
max_seq_len_,
head_num_,
size_per_head_,
vocab_size,
num_layer_,
memory_hidden_dim,
memory_max_seq_len,
start_id_,
end_id_,
candidate_num_,
probability_threshold_,
fuse_qkv);
decoding_sampling_ = new DecodingSampling<DecodingTraits_::OpType>(
allocator_,
batch_size_,
max_seq_len_,
head_num_,
size_per_head_,
vocab_size,
num_layer_,
memory_hidden_dim,
memory_max_seq_len,
start_id_,
end_id_,
candidate_num_,
probability_threshold_,
fuse_qkv,
true, // normalization_before
0, // pos_offset
ActivationType::RELU, // act
false, // pos_bias
1.0, // temperature
1.0, // repeat_penalty
false, // prefix_lm
false, // is_mbart
0, // min_length
inner_coeff);

decoding_sampling_->forward(params, decoding_params);

Expand Down
73 changes: 54 additions & 19 deletions paddlenlp/ops/faster_transformer/src/fusion_force_decoding_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include <cuda_runtime.h>
#include <curand.h>
#include <curand_kernel.h>

#include <algorithm>
#include <iterator>
#include <random>
Expand Down Expand Up @@ -155,6 +156,8 @@ std::vector<paddle::Tensor> decoding_kernel(
DecoderInitParam<DataType_>* params =
new DecoderInitParam<DataType_>[num_layer_];

int inner_coeff = ffn_intermediate_weight[0].shape()[1] / memory_hidden_dim;

auto q_weight_shape = self_attn_query_weight[0].shape();
auto k_weight_shape = self_attn_key_weight[0].shape();
bool fuse_qkv = (q_weight_shape[1] == k_weight_shape[1]) ? false : true;
Expand Down Expand Up @@ -296,7 +299,19 @@ std::vector<paddle::Tensor> decoding_kernel(
end_id_,
beam_search_diversity_rate_,
true, // is_fuse_topk_softMax
fuse_qkv); // is_fuse_qkv
fuse_qkv,
false, // keep_alive_beam
0.6, // alpha
true, // normalization_before
0, // pos_offset
ActivationType::RELU, // act
false, // pos_bias
false, // prefix_lm
-1, // finished_candidate_num
false, // early_stopping
false, // is_mbart
0, // min_length
inner_coeff);

decoding_beam_search_->forward(params, decoding_params);

Expand All @@ -317,10 +332,20 @@ std::vector<paddle::Tensor> decoding_kernel(
start_id_,
end_id_,
beam_search_diversity_rate_,
true, // is_fuse_topk_softMax
true, // is_fuse_topk_softMax
fuse_qkv, // is_fuse_qkv
true, // keep_alive_beam
alpha);
true, // keep_alive_beam
alpha,
true, // normalization_before
0, // pos_offset
ActivationType::RELU, // act
false, // pos_bias
false, // prefix_lm
-1, // finished_candidate_num
false, // early_stopping
false, // is_mbart
0, // min_length
inner_coeff);

decoding_beam_search_->forward(params, decoding_params);

Expand All @@ -329,21 +354,31 @@ std::vector<paddle::Tensor> decoding_kernel(
"topp_sampling" == decoding_strategy ||
"sampling" == decoding_strategy) {
DecodingSampling<DecodingTraits_::OpType>* decoding_sampling_;
decoding_sampling_ =
new DecodingSampling<DecodingTraits_::OpType>(allocator_,
batch_size_,
max_seq_len_,
head_num_,
size_per_head_,
vocab_size,
num_layer_,
memory_hidden_dim,
memory_max_seq_len,
start_id_,
end_id_,
candidate_num_,
probability_threshold_,
fuse_qkv);
decoding_sampling_ = new DecodingSampling<DecodingTraits_::OpType>(
allocator_,
batch_size_,
max_seq_len_,
head_num_,
size_per_head_,
vocab_size,
num_layer_,
memory_hidden_dim,
memory_max_seq_len,
start_id_,
end_id_,
candidate_num_,
probability_threshold_,
fuse_qkv,
true, // normalization_before
0, // pos_offset
ActivationType::RELU, // act
false, // pos_bias
1.0, // temperature
1.0, // repeat_penalty
false, // prefix_lm
false, // is_mbart
0, // min_length
inner_coeff);

decoding_sampling_->forward(params, decoding_params);

Expand Down
71 changes: 39 additions & 32 deletions paddlenlp/ops/faster_transformer/transformer/faster_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,45 +249,46 @@ def forward(self, src_word, trg_word=None):

return ids

def load(self, init_from_params):
def load(self, init_from_params=None, state_dict=None):
# Load the trained model
assert init_from_params, (
"Please set init_from_params to load the infer model.")
if init_from_params is None and state_dict is None:
raise ValueError(
"Either init_from_params or state_dict must be given to load the infer model. "
)

model_dict = paddle.load(init_from_params, return_numpy=True)
if state_dict is None:
state_dict = paddle.load(init_from_params, return_numpy=True)
else:
for state in state_dict:
# NOTE: This API only used in dygraph, so paddle.Tensor is enough.
if isinstance(state_dict[state], paddle.Tensor):
state_dict[state] = state_dict[state].numpy()

# To set weight[padding_idx] to 0.
model_dict["trg_word_embedding.word_embedding.weight"][
state_dict["trg_word_embedding.word_embedding.weight"][
self.bos_id] = [0] * self.d_model

# Dealing with weight sharing.
if self.weight_sharing:
model_dict["decoding_linear.weight"] = np.transpose(
model_dict["trg_word_embedding.word_embedding.weight"])
state_dict["decoding_linear.weight"] = np.transpose(
state_dict["trg_word_embedding.word_embedding.weight"])
else:
model_dict["decoding_linear.weight"] = model_dict["linear.weight"]

# To avoid a longer length than training, reset the size of position
# encoding to max_length
model_dict["encoder.pos_encoder.weight"] = position_encoding_init(
self.max_length, self.d_model)
model_dict["decoder.pos_encoder.weight"] = position_encoding_init(
self.max_length, self.d_model)
state_dict["decoding_linear.weight"] = state_dict["linear.weight"]

if self.decoding._fuse_qkv:
for item in self.state_dict():
if "decoder" in item and "self_attn.q_proj" in item:
num_layer = item.split(".")[3]
param_type = item.split(".")[-1]

model_dict["decoding.slf_q_" + param_type + "_" +
state_dict["decoding.slf_q_" + param_type + "_" +
num_layer] = np.concatenate(
(model_dict[item],
model_dict["transformer.decoder.layers." +
(state_dict[item],
state_dict["transformer.decoder.layers." +
num_layer +
".self_attn.k_proj." +
param_type],
model_dict["transformer.decoder.layers." +
state_dict["transformer.decoder.layers." +
num_layer +
".self_attn.v_proj." +
param_type]),
Expand All @@ -296,17 +297,17 @@ def load(self, init_from_params):
if self.use_fp16_decoding:
for item in self.state_dict():
if "decoder" in item or "decoding.slf" in item:
model_dict[item] = np.float16(model_dict[item])
model_dict["decoding_linear.weight"] = np.float16(
model_dict["decoding_linear.weight"])
model_dict["trg_word_embedding.word_embedding.weight"] = np.float16(
model_dict["trg_word_embedding.word_embedding.weight"])
model_dict["trg_pos_embedding.pos_encoder.weight"] = np.float16(
model_dict["trg_pos_embedding.pos_encoder.weight"])
model_dict["decoding_linear.bias"] = np.zeros([self.trg_vocab_size],
state_dict[item] = np.float16(state_dict[item])
state_dict["decoding_linear.weight"] = np.float16(
state_dict["decoding_linear.weight"])
state_dict["trg_word_embedding.word_embedding.weight"] = np.float16(
state_dict["trg_word_embedding.word_embedding.weight"])
state_dict["trg_pos_embedding.pos_encoder.weight"] = np.float16(
state_dict["trg_pos_embedding.pos_encoder.weight"])
state_dict["decoding_linear.bias"] = np.zeros([self.trg_vocab_size],
dtype="float16")

self.load_dict(model_dict)
self.load_dict(state_dict)

if self.enable_faster_encoder:
self = enable_faster_encoder(self, use_fp16=self.use_fp16_encoder)
Expand Down Expand Up @@ -695,12 +696,18 @@ def forward(self, src_word, trg_word=None):
out = paddle.transpose(out, [1, 0, 2])
return out

def load(self, path):
def load(self, path=None, state_dict=None):
if path is None and state_dict is None:
raise ValueError(
"Either path or state_dict must be given to load the infer model. "
)

if isinstance(self.transformer, FasterTransformer):
self.transformer.load(path)
self.transformer.load(path, state_dict)
else:
model_dict = paddle.load(path)
self.transformer.load_dict(model_dict)
if state_dict is None:
state_dict = paddle.load(path)
self.transformer.load_dict(state_dict)


class FasterOPT(OPTPretrainedModel):
Expand Down